Add HeFFTe based FFT backend
[alexxy/gromacs.git] / src / gromacs / fft / gpu_3dfft.cpp
index c027c5d08a80bd761d6b4a116d4b73a2a0578ca1..9b931cfb2fd188c5ed63f1aded7474b3345243ac 100644 (file)
 #    include "gpu_3dfft_sycl.h"
 #endif
 
+#if Heffte_FOUND
+#    include "gpu_3dfft_heffte.h"
+#endif
+
 #include "gromacs/utility/arrayref.h"
 #include "gromacs/utility/exceptions.h"
 
@@ -101,7 +105,9 @@ Gpu3dFft::Gpu3dFft(FftBackend           backend,
                                                           realGrid,
                                                           complexGrid);
             break;
-        default: GMX_THROW(InternalError("Unsupported FFT backend requested"));
+        default:
+            GMX_RELEASE_ASSERT(backend == FftBackend::HeFFTe_CUDA,
+                               "Unsupported FFT backend requested");
     }
 #    elif GMX_GPU_OPENCL
     switch (backend)
@@ -144,6 +150,35 @@ Gpu3dFft::Gpu3dFft(FftBackend           backend,
         default: GMX_THROW(InternalError("Unsupported FFT backend requested"));
     }
 #    endif
+
+#    if Heffte_FOUND
+    switch (backend)
+    {
+        case FftBackend::HeFFTe_CUDA:
+            GMX_RELEASE_ASSERT(
+                    GMX_GPU_CUDA,
+                    "HeFFTe_CUDA FFT backend is supported only with GROMACS compiled with CUDA");
+            GMX_RELEASE_ASSERT(heffte::backend::is_enabled<heffte::backend::cufft>::value,
+                               "HeFFTe not compiled with CUDA support");
+            impl_ = std::make_unique<Gpu3dFft::ImplHeFfte<heffte::backend::cufft>>(
+                    allocateGrids,
+                    comm,
+                    gridSizesInXForEachRank,
+                    gridSizesInYForEachRank,
+                    nz,
+                    performOutOfPlaceFFT,
+                    context,
+                    pmeStream,
+                    realGridSize,
+                    realGridSizePadded,
+                    complexGridSizePadded,
+                    realGrid,
+                    complexGrid);
+
+            break;
+        default: GMX_RELEASE_ASSERT(impl_ != nullptr, "Unsupported FFT backend requested");
+    }
+#    endif
 }
 
 #else