d63901d3197d2556785d9b1721e08dcd982c441f
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gpu_3dfft_ocl.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2016,2017,2018,2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *  \brief Implements OpenCL 3D FFT routines for PME GPU.
38  *
39  *  \author Aleksei Iupinov <a.yupinov@gmail.com>
40  *  \ingroup module_ewald
41  */
42
43 #include "gmxpre.h"
44
45 #include "pme_gpu_3dfft.h"
46
47 #include <array>
48 #include <vector>
49
50 #include <clFFT.h>
51
52 #include "gromacs/gpu_utils/device_context.h"
53 #include "gromacs/gpu_utils/device_stream.h"
54 #include "gromacs/gpu_utils/gmxopencl.h"
55 #include "gromacs/utility/exceptions.h"
56 #include "gromacs/utility/gmxassert.h"
57 #include "gromacs/utility/stringutil.h"
58
59 class GpuParallel3dFft::Impl
60 {
61 public:
62     Impl(ivec                 realGridSize,
63          ivec                 realGridSizePadded,
64          ivec                 complexGridSizePadded,
65          bool                 useDecomposition,
66          bool                 performOutOfPlaceFFT,
67          const DeviceContext& context,
68          const DeviceStream&  pmeStream,
69          DeviceBuffer<float>  realGrid,
70          DeviceBuffer<float>  complexGrid);
71     ~Impl();
72
73     clfftPlanHandle               planR2C_;
74     clfftPlanHandle               planC2R_;
75     std::vector<cl_command_queue> commandStreams_;
76     cl_mem                        realGrid_;
77     cl_mem                        complexGrid_;
78 };
79
80 //! Throws the exception on clFFT error
81 static void handleClfftError(clfftStatus status, const char* msg)
82 {
83     // Supposedly it's just a superset of standard OpenCL errors
84     if (status != CLFFT_SUCCESS)
85     {
86         GMX_THROW(gmx::InternalError(gmx::formatString("%s: %d", msg, status)));
87     }
88 }
89
90 GpuParallel3dFft::Impl::Impl(ivec                 realGridSize,
91                              ivec                 realGridSizePadded,
92                              ivec                 complexGridSizePadded,
93                              const bool           useDecomposition,
94                              const bool           performOutOfPlaceFFT,
95                              const DeviceContext& context,
96                              const DeviceStream&  pmeStream,
97                              DeviceBuffer<float>  realGrid,
98                              DeviceBuffer<float>  complexGrid) :
99     realGrid_(realGrid), complexGrid_(complexGrid)
100 {
101     GMX_RELEASE_ASSERT(!useDecomposition, "FFT decomposition not implemented");
102
103     cl_context clContext = context.context();
104     commandStreams_.push_back(pmeStream.stream());
105
106     // clFFT expects row-major, so dimensions/strides are reversed (ZYX instead of XYZ)
107     std::array<size_t, DIM> realGridDimensions = { size_t(realGridSize[ZZ]),
108                                                    size_t(realGridSize[YY]),
109                                                    size_t(realGridSize[XX]) };
110     std::array<size_t, DIM> realGridStrides    = {
111         1, size_t(realGridSizePadded[ZZ]), size_t(realGridSizePadded[YY] * realGridSizePadded[ZZ])
112     };
113     std::array<size_t, DIM> complexGridStrides = {
114         1, size_t(complexGridSizePadded[ZZ]), size_t(complexGridSizePadded[YY] * complexGridSizePadded[ZZ])
115     };
116
117     constexpr clfftDim dims = CLFFT_3D;
118     handleClfftError(clfftCreateDefaultPlan(&planR2C_, clContext, dims, realGridDimensions.data()),
119                      "clFFT planning failure");
120     handleClfftError(clfftSetResultLocation(planR2C_, performOutOfPlaceFFT ? CLFFT_OUTOFPLACE : CLFFT_INPLACE),
121                      "clFFT planning failure");
122     handleClfftError(clfftSetPlanPrecision(planR2C_, CLFFT_SINGLE), "clFFT planning failure");
123     constexpr cl_float scale = 1.0;
124     handleClfftError(clfftSetPlanScale(planR2C_, CLFFT_FORWARD, scale),
125                      "clFFT coefficient setup failure");
126     handleClfftError(clfftSetPlanScale(planR2C_, CLFFT_BACKWARD, scale),
127                      "clFFT coefficient setup failure");
128
129     // The only difference between 2 plans is direction
130     handleClfftError(clfftCopyPlan(&planC2R_, clContext, planR2C_), "clFFT plan copying failure");
131
132     handleClfftError(clfftSetLayout(planR2C_, CLFFT_REAL, CLFFT_HERMITIAN_INTERLEAVED),
133                      "clFFT R2C layout failure");
134     handleClfftError(clfftSetLayout(planC2R_, CLFFT_HERMITIAN_INTERLEAVED, CLFFT_REAL),
135                      "clFFT C2R layout failure");
136
137     handleClfftError(clfftSetPlanInStride(planR2C_, dims, realGridStrides.data()),
138                      "clFFT stride setting failure");
139     handleClfftError(clfftSetPlanOutStride(planR2C_, dims, complexGridStrides.data()),
140                      "clFFT stride setting failure");
141
142     handleClfftError(clfftSetPlanInStride(planC2R_, dims, complexGridStrides.data()),
143                      "clFFT stride setting failure");
144     handleClfftError(clfftSetPlanOutStride(planC2R_, dims, realGridStrides.data()),
145                      "clFFT stride setting failure");
146
147     handleClfftError(clfftBakePlan(planR2C_, commandStreams_.size(), commandStreams_.data(), nullptr, nullptr),
148                      "clFFT precompiling failure");
149     handleClfftError(clfftBakePlan(planC2R_, commandStreams_.size(), commandStreams_.data(), nullptr, nullptr),
150                      "clFFT precompiling failure");
151
152     // TODO: implement solve kernel as R2C FFT callback
153     // TODO: disable last transpose (clfftSetPlanTransposeResult)
154 }
155
156 GpuParallel3dFft::Impl::~Impl()
157 {
158     clfftDestroyPlan(&planR2C_);
159     clfftDestroyPlan(&planC2R_);
160 }
161
162 void GpuParallel3dFft::perform3dFft(gmx_fft_direction dir, CommandEvent* timingEvent)
163 {
164     cl_mem                            tempBuffer = nullptr;
165     constexpr std::array<cl_event, 0> waitEvents{ {} };
166
167     clfftPlanHandle plan;
168     clfftDirection  direction;
169     cl_mem *        inputGrids, *outputGrids;
170
171     switch (dir)
172     {
173         case GMX_FFT_REAL_TO_COMPLEX:
174             plan        = impl_->planR2C_;
175             direction   = CLFFT_FORWARD;
176             inputGrids  = &impl_->realGrid_;
177             outputGrids = &impl_->complexGrid_;
178             break;
179         case GMX_FFT_COMPLEX_TO_REAL:
180             plan        = impl_->planC2R_;
181             direction   = CLFFT_BACKWARD;
182             inputGrids  = &impl_->complexGrid_;
183             outputGrids = &impl_->realGrid_;
184             break;
185         default:
186             GMX_THROW(
187                     gmx::NotImplementedError("The chosen 3D-FFT case is not implemented on GPUs"));
188     }
189     handleClfftError(clfftEnqueueTransform(plan,
190                                            direction,
191                                            impl_->commandStreams_.size(),
192                                            impl_->commandStreams_.data(),
193                                            waitEvents.size(),
194                                            waitEvents.data(),
195                                            timingEvent,
196                                            inputGrids,
197                                            outputGrids,
198                                            tempBuffer),
199                      "clFFT execution failure");
200 }
201
202 GpuParallel3dFft::GpuParallel3dFft(ivec                 realGridSize,
203                                    ivec                 realGridSizePadded,
204                                    ivec                 complexGridSizePadded,
205                                    const bool           useDecomposition,
206                                    const bool           performOutOfPlaceFFT,
207                                    const DeviceContext& context,
208                                    const DeviceStream&  pmeStream,
209                                    DeviceBuffer<float>  realGrid,
210                                    DeviceBuffer<float>  complexGrid) :
211     impl_(std::make_unique<Impl>(realGridSize,
212                                  realGridSizePadded,
213                                  complexGridSizePadded,
214                                  useDecomposition,
215                                  performOutOfPlaceFFT,
216                                  context,
217                                  pmeStream,
218                                  realGrid,
219                                  complexGrid))
220 {
221 }
222
// Defaulted out of line, in the translation unit where Impl is a complete type,
// so the impl_ member can be destroyed here (which releases the clFFT plans via
// ~Impl). NOTE(review): assumes impl_ is a std::unique_ptr<Impl>, as its
// std::make_unique initialization suggests — confirm against the header.
GpuParallel3dFft::~GpuParallel3dFft() = default;