bf4db7b7a2bd4130c35499b15d505ddbd871c1a5
[alexxy/gromacs.git] / src / gromacs / ewald / pme-solve.cu
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2016,2017, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *  \brief Implements PME GPU Fourier grid solving in CUDA.
38  *
39  *  \author Aleksei Iupinov <a.yupinov@gmail.com>
40  */
41
42 #include "gmxpre.h"
43
44 #include "config.h"
45
46 #include "gromacs/gpu_utils/cuda_arch_utils.cuh"
47 #include "gromacs/gpu_utils/cudautils.cuh"
48 #include "gromacs/utility/exceptions.h"
49 #include "gromacs/utility/gmxassert.h"
50
51 #include "pme.cuh"
52 #include "pme-timings.cuh"
53
54 //! Solving kernel max block width in warps picked among powers of 2 (2, 4, 8, 16) for max. occupancy and min. runtime
55 //! (560Ti (CC2.1), 660Ti (CC3.0) and 750 (CC5.0)))
56 constexpr int c_solveMaxWarpsPerBlock = 8;
57 //! Solving kernel max block size in threads
58 constexpr int c_solveMaxThreadsPerBlock = (c_solveMaxWarpsPerBlock * warp_size);
59
60 // CUDA 6.5 can not compile enum class as a template kernel parameter,
61 // so we replace it with a duplicate simple enum
62 #if GMX_CUDA_VERSION >= 7000
63 using GridOrderingInternal = GridOrdering;
64 #else
65 enum GridOrderingInternal
66 {
67     YZX,
68     XYZ
69 };
70 #endif
71
72 /*! \brief
73  * PME complex grid solver kernel function.
74  *
75  * \tparam[in] gridOrdering             Specifies the dimension ordering of the complex grid.
76  * \tparam[in] computeEnergyAndVirial   Tells if the reciprocal energy and virial should be computed.
77  * \param[in]  kernelParams             Input PME CUDA data in constant memory.
78  */
79 template<
80     GridOrderingInternal gridOrdering,
81     bool computeEnergyAndVirial
82     >
83 __launch_bounds__(c_solveMaxThreadsPerBlock)
84 __global__ void pme_solve_kernel(const struct PmeGpuCudaKernelParams kernelParams)
85 {
86     /* This kernel supports 2 different grid dimension orderings: YZX and XYZ */
87     int majorDim, middleDim, minorDim;
88     switch (gridOrdering)
89     {
90         case GridOrderingInternal::YZX:
91             majorDim  = YY;
92             middleDim = ZZ;
93             minorDim  = XX;
94             break;
95
96         case GridOrderingInternal::XYZ:
97             majorDim  = XX;
98             middleDim = YY;
99             minorDim  = ZZ;
100             break;
101
102         default:
103             assert(false);
104     }
105
106     /* Global memory pointers */
107     const float * __restrict__ gm_splineValueMajor    = kernelParams.grid.d_splineModuli + kernelParams.grid.splineValuesOffset[majorDim];
108     const float * __restrict__ gm_splineValueMiddle   = kernelParams.grid.d_splineModuli + kernelParams.grid.splineValuesOffset[middleDim];
109     const float * __restrict__ gm_splineValueMinor    = kernelParams.grid.d_splineModuli + kernelParams.grid.splineValuesOffset[minorDim];
110     float * __restrict__       gm_virialAndEnergy     = kernelParams.constants.d_virialAndEnergy;
111     float2 * __restrict__      gm_grid                = (float2 *)kernelParams.grid.d_fourierGrid;
112
113     /* Various grid sizes and indices */
114     const int localOffsetMinor = 0, localOffsetMajor = 0, localOffsetMiddle = 0; //unused
115     const int localSizeMinor   = kernelParams.grid.complexGridSizePadded[minorDim];
116     const int localSizeMiddle  = kernelParams.grid.complexGridSizePadded[middleDim];
117     const int localCountMiddle = kernelParams.grid.complexGridSize[middleDim];
118     const int localCountMinor  = kernelParams.grid.complexGridSize[minorDim];
119     const int nMajor           = kernelParams.grid.realGridSize[majorDim];
120     const int nMiddle          = kernelParams.grid.realGridSize[middleDim];
121     const int nMinor           = kernelParams.grid.realGridSize[minorDim];
122     const int maxkMajor        = (nMajor + 1) / 2;  // X or Y
123     const int maxkMiddle       = (nMiddle + 1) / 2; // Y OR Z => only check for !YZX
124     const int maxkMinor        = (nMinor + 1) / 2;  // Z or X => only check for YZX
125
126     /* Each thread works on one cell of the Fourier space complex 3D grid (gm_grid).
127      * Each block handles up to c_solveMaxThreadsPerBlock cells -
128      * depending on the grid contiguous dimension size,
129      * that can range from a part of a single gridline to several complete gridlines.
130      */
131     const int threadLocalId     = threadIdx.x;
132     const int gridLineSize      = localCountMinor;
133     const int gridLineIndex     = threadLocalId / gridLineSize;
134     const int gridLineCellIndex = threadLocalId - gridLineSize * gridLineIndex;
135     const int gridLinesPerBlock = blockDim.x / gridLineSize;
136     const int activeWarps       = (blockDim.x / warp_size);
137     const int indexMinor        = blockIdx.x * blockDim.x + gridLineCellIndex;
138     const int indexMiddle       = blockIdx.y * gridLinesPerBlock + gridLineIndex;
139     const int indexMajor        = blockIdx.z;
140
141     /* Optional outputs */
142     float energy = 0.0f;
143     float virxx  = 0.0f;
144     float virxy  = 0.0f;
145     float virxz  = 0.0f;
146     float viryy  = 0.0f;
147     float viryz  = 0.0f;
148     float virzz  = 0.0f;
149
150     assert(indexMajor < kernelParams.grid.complexGridSize[majorDim]);
151     if ((indexMiddle < localCountMiddle) & (indexMinor < localCountMinor) & (gridLineIndex < gridLinesPerBlock))
152     {
153         /* The offset should be equal to the global thread index for coalesced access */
154         const int             gridIndex     = (indexMajor * localSizeMiddle + indexMiddle) * localSizeMinor + indexMinor;
155         float2 * __restrict__ gm_gridCell   = gm_grid + gridIndex;
156
157         const int             kMajor  = indexMajor + localOffsetMajor;
158         /* Checking either X in XYZ, or Y in YZX cases */
159         const float           mMajor  = (kMajor < maxkMajor) ? kMajor : (kMajor - nMajor);
160
161         const int             kMiddle = indexMiddle + localOffsetMiddle;
162         float                 mMiddle = kMiddle;
163         /* Checking Y in XYZ case */
164         if (gridOrdering == GridOrderingInternal::XYZ)
165         {
166             mMiddle = (kMiddle < maxkMiddle) ? kMiddle : (kMiddle - nMiddle);
167         }
168         const int             kMinor  = localOffsetMinor + indexMinor;
169         float                 mMinor  = kMinor;
170         /* Checking X in YZX case */
171         if (gridOrdering == GridOrderingInternal::YZX)
172         {
173             mMinor = (kMinor < maxkMinor) ? kMinor : (kMinor - nMinor);
174         }
175         /* We should skip the k-space point (0,0,0) */
176         const bool notZeroPoint  = (kMinor > 0) | (kMajor > 0) | (kMiddle > 0);
177
178         float      mX, mY, mZ;
179         switch (gridOrdering)
180         {
181             case GridOrderingInternal::YZX:
182                 mX = mMinor;
183                 mY = mMajor;
184                 mZ = mMiddle;
185                 break;
186
187             case GridOrderingInternal::XYZ:
188                 mX = mMajor;
189                 mY = mMiddle;
190                 mZ = mMinor;
191                 break;
192
193             default:
194                 assert(false);
195         }
196
197         /* 0.5 correction factor for the first and last components of a Z dimension */
198         float corner_fac = 1.0f;
199         switch (gridOrdering)
200         {
201             case GridOrderingInternal::YZX:
202                 if ((kMiddle == 0) | (kMiddle == maxkMiddle))
203                 {
204                     corner_fac = 0.5f;
205                 }
206                 break;
207
208             case GridOrderingInternal::XYZ:
209                 if ((kMinor == 0) | (kMinor == maxkMinor))
210                 {
211                     corner_fac = 0.5f;
212                 }
213                 break;
214
215             default:
216                 assert(false);
217         }
218
219         if (notZeroPoint)
220         {
221             const float mhxk = mX * kernelParams.current.recipBox[XX][XX];
222             const float mhyk = mX * kernelParams.current.recipBox[XX][YY] + mY * kernelParams.current.recipBox[YY][YY];
223             const float mhzk = mX * kernelParams.current.recipBox[XX][ZZ] + mY * kernelParams.current.recipBox[YY][ZZ] + mZ * kernelParams.current.recipBox[ZZ][ZZ];
224
225             const float m2k        = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
226             assert(m2k != 0.0f);
227             //TODO: use LDG/textures for gm_splineValue
228             float       denom = m2k * float(M_PI) * kernelParams.current.boxVolume * gm_splineValueMajor[kMajor] * gm_splineValueMiddle[kMiddle] * gm_splineValueMinor[kMinor];
229             assert(isfinite(denom));
230             assert(denom != 0.0f);
231             const float   tmp1   = expf(-kernelParams.grid.ewaldFactor * m2k);
232             const float   etermk = kernelParams.constants.elFactor * tmp1 / denom;
233
234             float2        gridValue    = *gm_gridCell;
235             const float2  oldGridValue = gridValue;
236             gridValue.x   *= etermk;
237             gridValue.y   *= etermk;
238             *gm_gridCell   = gridValue;
239
240             if (computeEnergyAndVirial)
241             {
242                 const float tmp1k = 2.0f * (gridValue.x * oldGridValue.x + gridValue.y * oldGridValue.y);
243
244                 float       vfactor = (kernelParams.grid.ewaldFactor + 1.0f / m2k) * 2.0f;
245                 float       ets2    = corner_fac * tmp1k;
246                 energy = ets2;
247
248                 float ets2vf  = ets2 * vfactor;
249
250                 virxx   = ets2vf * mhxk * mhxk - ets2;
251                 virxy   = ets2vf * mhxk * mhyk;
252                 virxz   = ets2vf * mhxk * mhzk;
253                 viryy   = ets2vf * mhyk * mhyk - ets2;
254                 viryz   = ets2vf * mhyk * mhzk;
255                 virzz   = ets2vf * mhzk * mhzk - ets2;
256             }
257         }
258     }
259
260     /* Optional energy/virial reduction */
261     if (computeEnergyAndVirial)
262     {
263 #if (GMX_PTX_ARCH >= 300)
264         /* A tricky shuffle reduction inspired by reduce_force_j_warp_shfl.
265          * The idea is to reduce 7 energy/virial components into a single variable (aligned by 8).
266          * We will reduce everything into virxx.
267          */
268
269         /* We can only reduce warp-wise */
270         const int          width      = warp_size;
271         const unsigned int activeMask = c_fullWarpMask;
272
273         /* Making pair sums */
274         virxx  += gmx_shfl_down_sync(activeMask, virxx, 1, width);
275         viryy  += gmx_shfl_up_sync  (activeMask, viryy, 1, width);
276         virzz  += gmx_shfl_down_sync(activeMask, virzz, 1, width);
277         virxy  += gmx_shfl_up_sync  (activeMask, virxy, 1, width);
278         virxz  += gmx_shfl_down_sync(activeMask, virxz, 1, width);
279         viryz  += gmx_shfl_up_sync  (activeMask, viryz, 1, width);
280         energy += gmx_shfl_down_sync(activeMask, energy, 1, width);
281         if (threadLocalId & 1)
282         {
283             virxx = viryy; // virxx now holds virxx and viryy pair sums
284             virzz = virxy; // virzz now holds virzz and virxy pair sums
285             virxz = viryz; // virxz now holds virxz and viryz pair sums
286         }
287
288         /* Making quad sums */
289         virxx  += gmx_shfl_down_sync(activeMask, virxx, 2, width);
290         virzz  += gmx_shfl_up_sync  (activeMask, virzz, 2, width);
291         virxz  += gmx_shfl_down_sync(activeMask, virxz, 2, width);
292         energy += gmx_shfl_up_sync  (activeMask, energy, 2, width);
293         if (threadLocalId & 2)
294         {
295             virxx = virzz;  // virxx now holds quad sums of virxx, virxy, virzz and virxy
296             virxz = energy; // virxz now holds quad sums of virxz, viryz, energy and unused paddings
297         }
298
299         /* Making octet sums */
300         virxx += gmx_shfl_down_sync(activeMask, virxx, 4, width);
301         virxz += gmx_shfl_up_sync  (activeMask, virxz, 4, width);
302         if (threadLocalId & 4)
303         {
304             virxx = virxz; // virxx now holds all 7 components' octet sums + unused paddings
305         }
306
307         /* We only need to reduce virxx now */
308 #pragma unroll
309         for (int delta = 8; delta < width; delta <<= 1)
310         {
311             virxx += gmx_shfl_down_sync(activeMask, virxx, delta, width);
312         }
313         /* Now first 7 threads of each warp have the full output contributions in virxx */
314
315         const int        componentIndex      = threadLocalId & (warp_size - 1);
316         const bool       validComponentIndex = (componentIndex < c_virialAndEnergyCount);
317         /* Reduce 7 outputs per warp in the shared memory */
318         const int        stride              = 8; // this is c_virialAndEnergyCount==7 rounded up to power of 2 for convenience, hence the assert
319         assert(c_virialAndEnergyCount == 7);
320         const int        reductionBufferSize = (c_solveMaxThreadsPerBlock / warp_size) * stride;
321         __shared__ float sm_virialAndEnergy[reductionBufferSize];
322
323         if (validComponentIndex)
324         {
325             const int warpIndex = threadLocalId / warp_size;
326             sm_virialAndEnergy[warpIndex * stride + componentIndex] = virxx;
327         }
328         __syncthreads();
329
330         /* Reduce to the single warp size */
331         const int targetIndex = threadLocalId;
332 #pragma unroll
333         for (int reductionStride = reductionBufferSize >> 1; reductionStride >= warp_size; reductionStride >>= 1)
334         {
335             const int sourceIndex = targetIndex + reductionStride;
336             if ((targetIndex < reductionStride) & (sourceIndex < activeWarps * stride))
337             {
338                 // TODO: the second conditional is only needed on first iteration, actually - see if compiler eliminates it!
339                 sm_virialAndEnergy[targetIndex] += sm_virialAndEnergy[sourceIndex];
340             }
341             __syncthreads();
342         }
343
344         /* Now use shuffle again */
345         if (threadLocalId < warp_size)
346         {
347             float output = sm_virialAndEnergy[threadLocalId];
348 #pragma unroll
349             for (int delta = stride; delta < warp_size; delta <<= 1)
350             {
351                 output += gmx_shfl_down_sync(activeMask, output, delta, warp_size);
352             }
353             /* Final output */
354             if (validComponentIndex)
355             {
356                 assert(isfinite(output));
357                 atomicAdd(gm_virialAndEnergy + componentIndex, output);
358             }
359         }
360 #else
361         /* Shared memory reduction with atomics for compute capability < 3.0.
362          * Each component is first reduced into warp_size positions in the shared memory;
363          * Then first c_virialAndEnergyCount warps reduce everything further and add to the global memory.
364          * This can likely be improved, but is anyway faster than the previous straightforward reduction,
365          * which was using too much shared memory (for storing all 7 floats on each thread).
366          * [48KB (shared mem limit per SM on CC2.x) / sizeof(float) (4) / c_solveMaxThreadsPerBlock (256) / c_virialAndEnergyCount (7) ==
367          * 6 blocks per SM instead of 16 which is maximum on CC2.x].
368          */
369
370         const int        lane      = threadLocalId & (warp_size - 1);
371         const int        warpIndex = threadLocalId / warp_size;
372         const bool       firstWarp = (warpIndex == 0);
373         __shared__ float sm_virialAndEnergy[c_virialAndEnergyCount * warp_size];
374         if (firstWarp)
375         {
376             sm_virialAndEnergy[0 * warp_size + lane] = virxx;
377             sm_virialAndEnergy[1 * warp_size + lane] = viryy;
378             sm_virialAndEnergy[2 * warp_size + lane] = virzz;
379             sm_virialAndEnergy[3 * warp_size + lane] = virxy;
380             sm_virialAndEnergy[4 * warp_size + lane] = virxz;
381             sm_virialAndEnergy[5 * warp_size + lane] = viryz;
382             sm_virialAndEnergy[6 * warp_size + lane] = energy;
383         }
384         __syncthreads();
385         if (!firstWarp)
386         {
387             atomicAdd(sm_virialAndEnergy + 0 * warp_size + lane, virxx);
388             atomicAdd(sm_virialAndEnergy + 1 * warp_size + lane, viryy);
389             atomicAdd(sm_virialAndEnergy + 2 * warp_size + lane, virzz);
390             atomicAdd(sm_virialAndEnergy + 3 * warp_size + lane, virxy);
391             atomicAdd(sm_virialAndEnergy + 4 * warp_size + lane, virxz);
392             atomicAdd(sm_virialAndEnergy + 5 * warp_size + lane, viryz);
393             atomicAdd(sm_virialAndEnergy + 6 * warp_size + lane, energy);
394         }
395         __syncthreads();
396
397         GMX_UNUSED_VALUE(activeWarps);
398         assert(activeWarps >= c_virialAndEnergyCount); // we need to cover all components, or have multiple iterations otherwise
399         const int componentIndex = warpIndex;
400         if (componentIndex < c_virialAndEnergyCount)
401         {
402             const int targetIndex = threadLocalId;
403 #pragma unroll
404             for (int reductionStride = warp_size >> 1; reductionStride >= 1; reductionStride >>= 1)
405             {
406                 if (lane < reductionStride)
407                 {
408                     sm_virialAndEnergy[targetIndex] += sm_virialAndEnergy[targetIndex + reductionStride];
409                 }
410             }
411             if (lane == 0)
412             {
413                 atomicAdd(gm_virialAndEnergy + componentIndex, sm_virialAndEnergy[targetIndex]);
414             }
415         }
416 #endif
417     }
418 }
419
420 void pme_gpu_solve(const PmeGpu *pmeGpu, t_complex *h_grid,
421                    GridOrdering gridOrdering, bool computeEnergyAndVirial)
422 {
423     const bool   copyInputAndOutputGrid = pme_gpu_is_testing(pmeGpu) || !pme_gpu_performs_FFT(pmeGpu);
424
425     cudaStream_t stream          = pmeGpu->archSpecific->pmeStream;
426     const auto  *kernelParamsPtr = pmeGpu->kernelParams.get();
427
428     if (copyInputAndOutputGrid)
429     {
430         cu_copy_H2D_async(kernelParamsPtr->grid.d_fourierGrid, h_grid, pmeGpu->archSpecific->complexGridSize * sizeof(float), stream);
431     }
432
433     int majorDim = -1, middleDim = -1, minorDim = -1;
434     switch (gridOrdering)
435     {
436         case GridOrdering::YZX:
437             majorDim  = YY;
438             middleDim = ZZ;
439             minorDim  = XX;
440             break;
441
442         case GridOrdering::XYZ:
443             majorDim  = XX;
444             middleDim = YY;
445             minorDim  = ZZ;
446             break;
447
448         default:
449             GMX_ASSERT(false, "Implement grid ordering here and below for the kernel launch");
450     }
451
452     const int maxBlockSize      = c_solveMaxThreadsPerBlock;
453     const int gridLineSize      = pmeGpu->kernelParams->grid.complexGridSize[minorDim];
454     const int gridLinesPerBlock = std::max(maxBlockSize / gridLineSize, 1);
455     const int blocksPerGridLine = (gridLineSize + maxBlockSize - 1) / maxBlockSize;
456     const int cellsPerBlock     = gridLineSize * gridLinesPerBlock;
457     const int blockSize         = (cellsPerBlock + warp_size - 1) / warp_size * warp_size;
458     // rounding up to full warps so that shuffle operations produce defined results
459     dim3 threads(blockSize);
460     dim3 blocks(blocksPerGridLine,
461                 (pmeGpu->kernelParams->grid.complexGridSize[middleDim] + gridLinesPerBlock - 1) / gridLinesPerBlock,
462                 pmeGpu->kernelParams->grid.complexGridSize[majorDim]);
463
464     pme_gpu_start_timing(pmeGpu, gtPME_SOLVE);
465     if (gridOrdering == GridOrdering::YZX)
466     {
467         if (computeEnergyAndVirial)
468         {
469             pme_solve_kernel<GridOrderingInternal::YZX, true> <<< blocks, threads, 0, stream>>> (*kernelParamsPtr);
470         }
471         else
472         {
473             pme_solve_kernel<GridOrderingInternal::YZX, false> <<< blocks, threads, 0, stream>>> (*kernelParamsPtr);
474         }
475     }
476     else if (gridOrdering == GridOrdering::XYZ)
477     {
478         if (computeEnergyAndVirial)
479         {
480             pme_solve_kernel<GridOrderingInternal::XYZ, true> <<< blocks, threads, 0, stream>>> (*kernelParamsPtr);
481         }
482         else
483         {
484             pme_solve_kernel<GridOrderingInternal::XYZ, false> <<< blocks, threads, 0, stream>>> (*kernelParamsPtr);
485         }
486     }
487     CU_LAUNCH_ERR("pme_solve_kernel");
488     pme_gpu_stop_timing(pmeGpu, gtPME_SOLVE);
489
490     if (computeEnergyAndVirial)
491     {
492         cu_copy_D2H_async(pmeGpu->staging.h_virialAndEnergy, kernelParamsPtr->constants.d_virialAndEnergy,
493                           c_virialAndEnergyCount * sizeof(float), stream);
494     }
495
496     if (copyInputAndOutputGrid)
497     {
498         cu_copy_D2H_async(h_grid, kernelParamsPtr->grid.d_fourierGrid, pmeGpu->archSpecific->complexGridSize * sizeof(float), stream);
499     }
500 }