f89d1747c2305497547433d93e0d99b9631ac2af
[alexxy/gromacs.git] / src / gromacs / fft / tests / fft.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2012,2013,2014,2016,2017 by the GROMACS development team.
5  * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36 /*! \internal \file
37  * \brief
38  * Tests utilities for fft calculations.
39  *
40  * Current reference data is generated in double precision using the Reference
41  * build type, except for the compiler (Apple Clang).
42  *
43  * \author Roland Schulz <roland@utk.edu>
44  * \ingroup module_fft
45  */
46 #include "gmxpre.h"
47
48 #include "gromacs/fft/fft.h"
49
50 #include "config.h"
51
52 #include <algorithm>
53 #include <vector>
54
55 #include <gmock/gmock.h>
56 #include <gtest/gtest.h>
57
58 #include "gromacs/fft/gpu_3dfft.h"
59 #include "gromacs/fft/parallel_3dfft.h"
60 #include "gromacs/gpu_utils/clfftinitializer.h"
61 #if GMX_GPU
62 #    include "gromacs/gpu_utils/devicebuffer.h"
63 #endif
64 #include "gromacs/utility/stringutil.h"
65
66 #include "testutils/refdata.h"
67 #include "testutils/test_hardware_environment.h"
68 #include "testutils/testasserts.h"
69 #include "testutils/testmatchers.h"
70
71 namespace gmx
72 {
73 namespace test
74 {
75 namespace
76 {
77
78 /*! \brief Input data for FFT tests.
79  *
80  * TODO If we require compilers that all support C++11 user literals,
81  * then this array could be of type real, initialized with e.g. -3.5_r
82  * that does not suffer from implicit narrowing with brace
83  * initializers, and we would not have to do so much useless copying
84  * during the unit tests below.
85  */
86 const double inputdata[500] = {
87     // print ",\n".join([",".join(["%4s"%(random.randint(-99,99)/10.,) for i in range(25)]) for j in range(20)])
88     -3.5, 6.3,  1.2,  0.3,  1.1,  -5.7, 5.8,  -1.9, -6.3, -1.4, 7.4,  2.4,  -9.9, -7.2, 5.4,  6.1,
89     -1.9, -7.6, 1.4,  -3.5, 0.7,  5.6,  -4.2, -1.1, -4.4, -6.3, -7.2, 4.6,  -3.0, -0.9, 7.2,  2.5,
90     -3.6, 6.1,  -3.2, -2.1, 6.5,  -0.4, -9.0, 2.3,  8.4,  4.0,  -5.2, -9.0, 4.7,  -3.7, -2.0, -9.5,
91     -3.9, -3.6, 7.1,  0.8,  -0.6, 5.2,  -9.3, -4.5, 5.9,  2.2,  -5.8, 5.0,  1.2,  -0.1, 2.2,  0.2,
92     -7.7, 1.9,  -8.4, 4.4,  2.3,  -2.9, 6.7,  2.7,  5.8,  -3.6, 8.9,  8.9,  4.3,  9.1,  9.3,  -8.7,
93     4.1,  9.6,  -6.2, 6.6,  -9.3, 8.2,  4.5,  6.2,  9.4,  -8.0, -6.8, -3.3, 7.2,  1.7,  0.6,  -4.9,
94     9.8,  1.3,  3.2,  -0.2, 9.9,  4.4,  -9.9, -7.2, 4.4,  4.7,  7.2,  -0.3, 0.3,  -2.1, 8.4,  -2.1,
95     -6.1, 4.1,  -5.9, -2.2, -3.8, 5.2,  -8.2, -7.8, -8.8, 6.7,  -9.5, -4.2, 0.8,  8.3,  5.2,  -9.0,
96     8.7,  9.8,  -9.9, -7.8, -8.3, 9.0,  -2.8, -9.2, -9.6, 8.4,  2.5,  6.0,  -0.4, 1.3,  -0.5, 9.1,
97     -9.5, -0.8, 1.9,  -6.2, 4.3,  -3.8, 8.6,  -1.9, -2.1, -0.4, -7.1, -3.7, 9.1,  -6.4, -0.6, 2.5,
98     8.0,  -5.2, -9.8, -4.3, 4.5,  1.7,  9.3,  9.2,  1.0,  5.3,  -4.5, 6.4,  -6.6, 3.1,  -6.8, 2.1,
99     2.0,  7.3,  8.6,  5.0,  5.2,  0.4,  -7.1, 4.5,  -9.2, -9.1, 0.2,  -6.3, -1.1, -9.6, 7.4,  -3.7,
100     -5.5, 2.6,  -3.5, -0.7, 9.0,  9.8,  -8.0, 3.6,  3.0,  -2.2, -2.8, 0.8,  9.0,  2.8,  7.7,  -0.7,
101     -5.0, -1.8, -2.3, -0.4, -6.2, -9.1, -9.2, 0.5,  5.7,  -3.9, 2.1,  0.6,  0.4,  9.1,  7.4,  7.1,
102     -2.5, 7.3,  7.8,  -4.3, 6.3,  -0.8, -3.8, -1.5, 6.6,  2.3,  3.9,  -4.6, 5.8,  -7.4, 5.9,  2.8,
103     4.7,  3.9,  -5.4, 9.1,  -1.6, -1.9, -4.2, -2.6, 0.6,  -5.1, 1.8,  5.2,  4.0,  -6.2, 6.5,  -9.1,
104     0.5,  2.1,  7.1,  -8.6, 7.6,  -9.7, -4.6, -5.7, 6.1,  -1.8, -7.3, 9.4,  8.0,  -2.6, -1.8, 5.7,
105     9.3,  -7.9, 7.4,  6.3,  2.0,  9.6,  -4.5, -6.2, 6.1,  2.3,  0.8,  5.9,  -2.8, -3.5, -1.5, 6.0,
106     -4.9, 3.5,  7.7,  -4.2, -9.7, 2.4,  8.1,  5.9,  3.4,  -7.5, 7.5,  2.6,  4.7,  2.7,  2.2,  2.6,
107     6.2,  7.5,  0.2,  -6.4, -2.8, -0.5, -0.3, 0.4,  1.2,  3.5,  -4.0, -0.5, 9.3,  -7.2, 8.5,  -5.5,
108     -1.7, -5.3, 0.3,  3.9,  -3.6, -3.6, 4.7,  -8.1, 1.4,  4.0,  1.3,  -4.3, -8.8, -7.3, 6.3,  -7.5,
109     -9.0, 9.1,  4.5,  -1.9, 1.9,  9.9,  -1.7, -9.1, -5.1, 8.5,  -9.3, 2.1,  -5.8, -3.6, -0.8, -0.9,
110     -3.3, -2.7, 7.0,  -7.2, -5.0, 7.4,  -1.4, 0.0,  -4.5, -9.7, 0.7,  -1.0, -9.1, -5.3, 4.3,  3.4,
111     -6.6, 9.8,  -1.1, 8.9,  5.0,  2.9,  0.2,  -2.9, 0.8,  6.7,  -0.6, 0.6,  4.1,  5.3,  -1.7, -0.3,
112     4.2,  3.7,  -8.3, 4.0,  1.3,  6.3,  0.2,  1.3,  -1.1, -3.5, 2.8,  -7.7, 6.2,  -4.9, -9.9, 9.6,
113     3.0,  -9.2, -8.0, -3.9, 7.9,  -6.1, 6.0,  5.9,  9.6,  1.2,  6.2,  3.6,  2.1,  5.8,  9.2,  -8.8,
114     8.8,  -3.3, -9.2, 4.6,  1.8,  4.6,  2.9,  -2.7, 4.2,  7.3,  -0.4, 7.7,  -7.0, 2.1,  0.3,  3.7,
115     3.3,  -8.6, 9.8,  3.6,  3.1,  6.5,  -2.4, 7.8,  7.5,  8.4,  -2.8, -6.3, -5.1, -2.7, 9.3,  -0.8,
116     -9.2, 7.9,  8.9,  3.4,  0.1,  -5.3, -6.8, 4.9,  4.3,  -0.7, -2.2, -3.2, -7.5, -2.3, 0.0,  8.1,
117     -9.2, -2.3, -5.7, 2.1,  2.6,  2.0,  0.3,  -8.0, -2.0, -7.9, 6.6,  8.4,  4.0,  -6.2, -6.9, -7.2,
118     7.7,  -5.0, 5.3,  1.9,  -5.3, -7.5, 8.8,  8.3,  9.0,  8.1,  3.2,  1.2,  -5.4, -0.2, 2.1,  -5.2,
119     9.5,  5.9,  5.6,  -7.8,
120 };
121
122
123 class BaseFFTTest : public ::testing::Test
124 {
125 public:
126     BaseFFTTest() : flags_(GMX_FFT_FLAG_CONSERVATIVE) {}
127     ~BaseFFTTest() override { gmx_fft_cleanup(); }
128
129     TestReferenceData data_;
130     std::vector<real> in_, out_;
131     int               flags_;
132     // TODO: These tolerances are just something that has been observed
133     // to be sufficient to pass the tests.  It would be nicer to
134     // actually argue about why they are sufficient (or what is).
135     // Should work for both one-way and forward+backward transform.
136     FloatingPointTolerance defaultTolerance_ = relativeToleranceAsPrecisionDependentUlp(10.0, 64, 512);
137 };
138
139 class FFTTest : public BaseFFTTest
140 {
141 public:
142     FFTTest() : fft_(nullptr) { checker_.setDefaultTolerance(defaultTolerance_); }
143     ~FFTTest() override
144     {
145         if (fft_)
146         {
147             gmx_fft_destroy(fft_);
148         }
149     }
150     TestReferenceChecker checker_ = data_.rootChecker();
151     gmx_fft_t            fft_;
152 };
153
154 class ManyFFTTest : public BaseFFTTest
155 {
156 public:
157     ManyFFTTest() : fft_(nullptr) { checker_.setDefaultTolerance(defaultTolerance_); }
158     ~ManyFFTTest() override
159     {
160         if (fft_)
161         {
162             gmx_many_fft_destroy(fft_);
163         }
164     }
165     TestReferenceChecker checker_ = data_.rootChecker();
166     gmx_fft_t            fft_;
167 };
168
169
170 // TODO: Add tests for aligned/not-aligned input/output memory
171
172 class FFTTest1D : public FFTTest, public ::testing::WithParamInterface<int>
173 {
174 };
175
176 class FFTTest3D : public BaseFFTTest
177 {
178 public:
179     FFTTest3D() : fft_(nullptr) {}
180     ~FFTTest3D() override
181     {
182         if (fft_)
183         {
184             gmx_parallel_3dfft_destroy(fft_);
185         }
186     }
187     gmx_parallel_3dfft_t fft_;
188 };
189
190
191 TEST_P(FFTTest1D, Complex)
192 {
193     const int nx = GetParam();
194     ASSERT_LE(nx * 2, static_cast<int>(sizeof(inputdata) / sizeof(inputdata[0])));
195
196     in_ = std::vector<real>(nx * 2);
197     std::copy(inputdata, inputdata + nx * 2, in_.begin());
198     out_      = std::vector<real>(nx * 2);
199     real* in  = &in_[0];
200     real* out = &out_[0];
201
202     gmx_fft_init_1d(&fft_, nx, flags_);
203
204     gmx_fft_1d(fft_, GMX_FFT_FORWARD, in, out);
205     checker_.checkSequenceArray(nx * 2, out, "forward");
206     gmx_fft_1d(fft_, GMX_FFT_BACKWARD, in, out);
207     checker_.checkSequenceArray(nx * 2, out, "backward");
208 }
209
210 TEST_P(FFTTest1D, Real)
211 {
212     const int rx = GetParam();
213     const int cx = (rx / 2 + 1);
214     ASSERT_LE(cx * 2, static_cast<int>(sizeof(inputdata) / sizeof(inputdata[0])));
215
216     in_ = std::vector<real>(cx * 2);
217     std::copy(inputdata, inputdata + cx * 2, in_.begin());
218     out_      = std::vector<real>(cx * 2);
219     real* in  = &in_[0];
220     real* out = &out_[0];
221
222     gmx_fft_init_1d_real(&fft_, rx, flags_);
223
224     gmx_fft_1d_real(fft_, GMX_FFT_REAL_TO_COMPLEX, in, out);
225     checker_.checkSequenceArray(cx * 2, out, "forward");
226     gmx_fft_1d_real(fft_, GMX_FFT_COMPLEX_TO_REAL, in, out);
227     checker_.checkSequenceArray(rx, out, "backward");
228 }
229
230 INSTANTIATE_TEST_SUITE_P(7_8_25_36_60, FFTTest1D, ::testing::Values(7, 8, 25, 36, 60));
231
232
233 TEST_F(ManyFFTTest, Complex1DLength48Multi5Test)
234 {
235     const int nx = 48;
236     const int N  = 5;
237
238     in_ = std::vector<real>(nx * 2 * N);
239     std::copy(inputdata, inputdata + nx * 2 * N, in_.begin());
240     out_      = std::vector<real>(nx * 2 * N);
241     real* in  = &in_[0];
242     real* out = &out_[0];
243
244     gmx_fft_init_many_1d(&fft_, nx, N, flags_);
245
246     gmx_fft_many_1d(fft_, GMX_FFT_FORWARD, in, out);
247     checker_.checkSequenceArray(nx * 2 * N, out, "forward");
248     gmx_fft_many_1d(fft_, GMX_FFT_BACKWARD, in, out);
249     checker_.checkSequenceArray(nx * 2 * N, out, "backward");
250 }
251
252 TEST_F(ManyFFTTest, Real1DLength48Multi5Test)
253 {
254     const int rx = 48;
255     const int cx = (rx / 2 + 1);
256     const int N  = 5;
257
258     in_ = std::vector<real>(cx * 2 * N);
259     std::copy(inputdata, inputdata + cx * 2 * N, in_.begin());
260     out_      = std::vector<real>(cx * 2 * N);
261     real* in  = &in_[0];
262     real* out = &out_[0];
263
264     gmx_fft_init_many_1d_real(&fft_, rx, N, flags_);
265
266     gmx_fft_many_1d_real(fft_, GMX_FFT_REAL_TO_COMPLEX, in, out);
267     checker_.checkSequenceArray(cx * 2 * N, out, "forward");
268     gmx_fft_many_1d_real(fft_, GMX_FFT_COMPLEX_TO_REAL, in, out);
269     checker_.checkSequenceArray(rx * N, out, "backward");
270 }
271
272 TEST_F(FFTTest, Real2DLength18_15Test)
273 {
274     const int rx = 18;
275     const int cx = (rx / 2 + 1);
276     const int ny = 15;
277
278     in_ = std::vector<real>(cx * 2 * ny);
279     std::copy(inputdata, inputdata + cx * 2 * ny, in_.begin());
280     out_      = std::vector<real>(cx * 2 * ny);
281     real* in  = &in_[0];
282     real* out = &out_[0];
283
284     gmx_fft_init_2d_real(&fft_, rx, ny, flags_);
285
286     gmx_fft_2d_real(fft_, GMX_FFT_REAL_TO_COMPLEX, in, out);
287     checker_.checkSequenceArray(cx * 2 * ny, out, "forward");
288     //    known to be wrong for gmx_fft_mkl. And not used.
289     //    gmx_fft_2d_real(_fft,GMX_FFT_COMPLEX_TO_REAL,in,out);
290     //    _checker.checkSequenceArray(rx*ny, out, "backward");
291 }
292
293 namespace
294 {
295
296 /*! \brief Check that the real grid after forward and backward
297  * 3D transforms matches the input real grid. */
298 void checkRealGrid(const ivec           realGridSize,
299                    const ivec           realGridSizePadded,
300                    ArrayRef<const real> inputRealGrid,
301                    ArrayRef<real>       outputRealGridValues)
302 {
303     // Normalize the output (as the implementation does not
304     // normalize either FFT)
305     const real normalizationConstant = 1.0 / (realGridSize[XX] * realGridSize[YY] * realGridSize[ZZ]);
306     std::transform(outputRealGridValues.begin(),
307                    outputRealGridValues.end(),
308                    outputRealGridValues.begin(),
309                    [normalizationConstant](const real r) { return r * normalizationConstant; });
310     // Check the real grid, skipping unused data from the padding
311     const auto realGridTolerance = relativeToleranceAsFloatingPoint(10, 1e-6);
312     for (int i = 0; i < realGridSize[XX] * realGridSize[YY]; i++)
313     {
314         auto expected =
315                 arrayRefFromArray(inputRealGrid.data() + i * realGridSizePadded[ZZ], realGridSize[ZZ]);
316         auto actual = arrayRefFromArray(outputRealGridValues.data() + i * realGridSizePadded[ZZ],
317                                         realGridSize[ZZ]);
318         EXPECT_THAT(actual, Pointwise(RealEq(realGridTolerance), expected))
319                 << formatString("checking backward transform part %d", i);
320     }
321 }
322
323 } // namespace
324
325 // TODO: test with threads and more than 1 MPI ranks
326 TEST_F(FFTTest3D, Real5_6_9)
327 {
328     int        realGridSize[] = { 5, 6, 9 };
329     MPI_Comm   comm[]         = { MPI_COMM_NULL, MPI_COMM_NULL };
330     real*      rdata;
331     t_complex* cdata;
332     ivec       local_ndata, offset, realGridSizePadded, complexGridSizePadded, complex_order;
333     TestReferenceChecker checker(data_.rootChecker());
334     checker.setDefaultTolerance(defaultTolerance_);
335
336     gmx_parallel_3dfft_init(&fft_, realGridSize, &rdata, &cdata, comm, TRUE, 1);
337
338     gmx_parallel_3dfft_real_limits(fft_, local_ndata, offset, realGridSizePadded);
339     gmx_parallel_3dfft_complex_limits(fft_, complex_order, local_ndata, offset, complexGridSizePadded);
340     checker.checkVector(realGridSizePadded, "realGridSizePadded");
341     checker.checkVector(complexGridSizePadded, "complexGridSizePadded");
342     int size = complexGridSizePadded[0] * complexGridSizePadded[1] * complexGridSizePadded[2];
343     int sizeInBytes = size * sizeof(t_complex);
344     int sizeInReals = sizeInBytes / sizeof(real);
345
346     // Prepare the real grid
347     in_ = std::vector<real>(sizeInReals);
348     // Use std::copy to convert from double to real easily
349     std::copy(inputdata, inputdata + sizeInReals, in_.begin());
350     // Use memcpy to convert to t_complex easily
351     memcpy(rdata, in_.data(), sizeInBytes);
352
353     // Do the forward FFT to compute the complex grid
354     gmx_parallel_3dfft_execute(fft_, GMX_FFT_REAL_TO_COMPLEX, 0, nullptr);
355
356     // Check the complex grid (NB this data has not been normalized)
357     ArrayRef<real> complexGridValues = arrayRefFromArray(reinterpret_cast<real*>(cdata), size * 2);
358     checker.checkSequence(
359             complexGridValues.begin(), complexGridValues.end(), "ComplexGridAfterRealToComplex");
360
361     // Do the back transform
362     gmx_parallel_3dfft_execute(fft_, GMX_FFT_COMPLEX_TO_REAL, 0, nullptr);
363
364     ArrayRef<real> outputRealGridValues = arrayRefFromArray(
365             rdata, realGridSizePadded[XX] * realGridSizePadded[YY] * realGridSizePadded[ZZ]);
366     checkRealGrid(realGridSize, realGridSizePadded, in_, outputRealGridValues);
367 }
368
369 #if GMX_GPU
370
371 /*! \brief Whether the FFT is in- or out-of-place
372  *
373  *  DPCPP uses oneMKL, which seems to have troubles with out-of-place
374  *  transforms. */
375 constexpr bool sc_performOutOfPlaceFFT = !((GMX_SYCL_DPCPP == 1) && (GMX_FFT_MKL == 1));
376
377 /*! \brief Return the output grid depending on whether in- or out-of
378  * place FFT is used
379  *
380  * Some versions of clang complain of unused code if we would just
381  * branch on the value of sc_performOutOfPlaceFFT at run time, because
382  * in any single configuration there would indeed be unused code. So
383  * the two template specializations are needed so that the compiler
384  * only compiles the template that is used. */
385 template<bool performOutOfPlaceFFT>
386 DeviceBuffer<float>* actualOutputGrid(DeviceBuffer<float>* realGrid, DeviceBuffer<float>* complexGrid);
387
388 #    if GMX_SYCL_DPCPP && GMX_FFT_MKL
389
390 template<>
391 DeviceBuffer<float>* actualOutputGrid<false>(DeviceBuffer<float>* realGrid,
392                                              DeviceBuffer<float>* /* complexGrid */)
393 {
394     return realGrid;
395 };
396
397 #    else
398
399 template<>
400 DeviceBuffer<float>* actualOutputGrid<true>(DeviceBuffer<float>* /* realGrid */, DeviceBuffer<float>* complexGrid)
401 {
402     return complexGrid;
403 }
404
405 #    endif
406
407 TEST_F(FFTTest3D, GpuReal5_6_9)
408 {
409     // Ensure library resources are managed appropriately
410     ClfftInitializer clfftInitializer;
411     for (const auto& testDevice : getTestHardwareEnvironment()->getTestDeviceList())
412     {
413         TestReferenceChecker checker(data_.rootChecker()); // Must be inside the loop to avoid warnings
414         checker.setDefaultTolerance(defaultTolerance_);
415
416         const DeviceContext& deviceContext = testDevice->deviceContext();
417         setActiveDevice(testDevice->deviceInfo());
418         const DeviceStream& deviceStream = testDevice->deviceStream();
419
420         ivec realGridSize       = { 5, 6, 9 };
421         ivec realGridSizePadded = { realGridSize[XX], realGridSize[YY], (realGridSize[ZZ] / 2 + 1) * 2 };
422         ivec complexGridSizePadded = { realGridSize[XX], realGridSize[YY], (realGridSize[ZZ] / 2) + 1 };
423
424         checker.checkVector(realGridSizePadded, "realGridSizePadded");
425         checker.checkVector(complexGridSizePadded, "complexGridSizePadded");
426
427         int size = complexGridSizePadded[0] * complexGridSizePadded[1] * complexGridSizePadded[2];
428         int sizeInReals = size * 2;
429         GMX_RELEASE_ASSERT(sizeof(inputdata) / sizeof(inputdata[0]) >= size_t(sizeInReals),
430                            "Size of inputdata is too small");
431
432         // Set up the complex grid. Complex numbers take twice the
433         // memory.
434         std::vector<float> complexGridValues(sizeInReals);
435         in_.resize(sizeInReals);
436         // Use std::copy to convert from double to real easily
437         std::copy(inputdata, inputdata + sizeInReals, in_.begin());
438
439 #    if GMX_GPU_CUDA
440         const FftBackend backend = FftBackend::Cufft;
441 #    elif GMX_GPU_OPENCL
442         const FftBackend backend = FftBackend::Ocl;
443 #    elif GMX_GPU_SYCL
444 #        if GMX_SYCL_HIPSYCL
445 #            if GMX_HIPSYCL_HAVE_HIP_TARGET
446         const FftBackend backend = FftBackend::SyclRocfft;
447 #            else
448         // Use stub backend so compilation succeeds
449         const FftBackend backend = FftBackend::Sycl;
450         // Don't complain about unused reference data
451         checker.disableUnusedEntriesCheck();
452         // Skip the rest of the test
453         GTEST_SKIP() << "Only rocFFT backend is supported with hipSYCL";
454 #            endif
455 #        elif GMX_SYCL_DPCPP
456 #            if GMX_FFT_MKL
457         const FftBackend backend = FftBackend::SyclMkl;
458 #            else
459         // Use stub backend so compilation succeeds
460         const FftBackend backend = FftBackend::Sycl;
461         // Don't complain about unused reference data
462         checker.disableUnusedEntriesCheck();
463         // Skip the rest of the test
464         GTEST_SKIP() << "Only MKL backend is supported with DPC++";
465 #            endif
466 #        else
467 #            error "Unsupported SYCL implementation"
468 #        endif
469 #    endif
470
471         SCOPED_TRACE("Allocating the device buffers");
472         DeviceBuffer<float> realGrid, complexGrid;
473         allocateDeviceBuffer(&realGrid, in_.size(), deviceContext);
474         if (sc_performOutOfPlaceFFT)
475         {
476             allocateDeviceBuffer(&complexGrid, complexGridValues.size(), deviceContext);
477         }
478
479         MPI_Comm           comm                    = MPI_COMM_NULL;
480         const bool         allocateGrid            = false;
481         std::array<int, 1> gridSizesInXForEachRank = { 0 };
482         std::array<int, 1> gridSizesInYForEachRank = { 0 };
483         const int          nz                      = realGridSize[ZZ];
484         Gpu3dFft           gpu3dFft(backend,
485                           allocateGrid,
486                           comm,
487                           gridSizesInXForEachRank,
488                           gridSizesInYForEachRank,
489                           nz,
490                           sc_performOutOfPlaceFFT,
491                           deviceContext,
492                           deviceStream,
493                           realGridSize,
494                           realGridSizePadded,
495                           complexGridSizePadded,
496                           &realGrid,
497                           actualOutputGrid<sc_performOutOfPlaceFFT>(&realGrid, &complexGrid));
498
499         // Transfer the real grid input data for the FFT
500         copyToDeviceBuffer(
501                 &realGrid, in_.data(), 0, in_.size(), deviceStream, GpuApiCallBehavior::Sync, nullptr);
502
503         // Do the forward FFT to compute the complex grid
504         CommandEvent* timingEvent = nullptr;
505         gpu3dFft.perform3dFft(GMX_FFT_REAL_TO_COMPLEX, timingEvent);
506         deviceStream.synchronize();
507
508         // Check the complex grid (NB this data has not been normalized)
509         copyFromDeviceBuffer(complexGridValues.data(),
510                              actualOutputGrid<sc_performOutOfPlaceFFT>(&realGrid, &complexGrid),
511                              0,
512                              complexGridValues.size(),
513                              deviceStream,
514                              GpuApiCallBehavior::Sync,
515                              nullptr);
516         checker.checkSequence(
517                 complexGridValues.begin(), complexGridValues.end(), "ComplexGridAfterRealToComplex");
518
519         std::vector<float> outputRealGridValues(in_.size());
520         if (sc_performOutOfPlaceFFT)
521         {
522             // Clear the real grid input data for the FFT so we can
523             // compute the back transform into it and observe that it did
524             // the work expected.
525             copyToDeviceBuffer(&realGrid,
526                                outputRealGridValues.data(),
527                                0,
528                                outputRealGridValues.size(),
529                                deviceStream,
530                                GpuApiCallBehavior::Sync,
531                                nullptr);
532         }
533
534         SCOPED_TRACE("Doing the back transform");
535         gpu3dFft.perform3dFft(GMX_FFT_COMPLEX_TO_REAL, timingEvent);
536         deviceStream.synchronize();
537
538         // Transfer the real grid back from the device
539         copyFromDeviceBuffer(outputRealGridValues.data(),
540                              &realGrid,
541                              0,
542                              outputRealGridValues.size(),
543                              deviceStream,
544                              GpuApiCallBehavior::Sync,
545                              nullptr);
546
547         checkRealGrid(realGridSize, realGridSizePadded, in_, outputRealGridValues);
548
549         SCOPED_TRACE("Cleaning up");
550         freeDeviceBuffer(&realGrid);
551         if (sc_performOutOfPlaceFFT)
552         {
553             freeDeviceBuffer(&complexGrid);
554         }
555     }
556 }
557
558 #endif
559
560 } // namespace
561 } // namespace test
562 } // namespace gmx