Redesign GPU FFT abstraction
[alexxy/gromacs.git] / src / gromacs / fft / tests / fft.cpp
index 7d0c7186d4090dd6307f8ae87d29ecbd2ff9e4ee..ce3717e210ffdf72c29f90dd90ac50613801e087 100644 (file)
@@ -400,17 +400,31 @@ TEST_F(FFTTest3D, GpuReal5_6_9)
         allocateDeviceBuffer(&realGrid, in_.size(), deviceContext);
         allocateDeviceBuffer(&complexGrid, complexGridValues.size(), deviceContext);
 
-        const bool useDecomposition     = false;
-        const bool performOutOfPlaceFFT = true;
-        Gpu3dFft   gpu3dFft(realGridSize,
-                          realGridSizePadded,
-                          complexGridSizePadded,
-                          useDecomposition,
+#    if GMX_GPU_CUDA
+        const FftBackend backend = FftBackend::Cufft;
+#    elif GMX_GPU_OPENCL
+        const FftBackend backend = FftBackend::Ocl;
+#    endif
+        const bool         performOutOfPlaceFFT    = true;
+        const MPI_Comm     comm                    = MPI_COMM_NULL;
+        const bool         allocateGrid            = false;
+        std::array<int, 1> gridSizesInXForEachRank = { 0 };
+        std::array<int, 1> gridSizesInYForEachRank = { 0 };
+        const int          nz                      = realGridSize[ZZ];
+        Gpu3dFft           gpu3dFft(backend,
+                          allocateGrid,
+                          comm,
+                          gridSizesInXForEachRank,
+                          gridSizesInYForEachRank,
+                          nz,
                           performOutOfPlaceFFT,
                           deviceContext,
                           deviceStream,
-                          realGrid,
-                          complexGrid);
+                          realGridSize,
+                          realGridSizePadded,
+                          complexGridSizePadded,
+                          &realGrid,
+                          &complexGrid);
 
         // Transfer the real grid input data for the FFT
         copyToDeviceBuffer(