Redesign GPU FFT abstraction
[alexxy/gromacs.git] / src / gromacs / fft / gpu_3dfft_cufft.cu
similarity index 66%
rename from src/gromacs/fft/gpu_3dfft.cu
rename to src/gromacs/fft/gpu_3dfft_cufft.cu
index 78f3ba90dcc0cf0bb2f70e2274d73ff5d9f25c2a..5ccdb9842e9335967b8d161edb644302d9dc57e5 100644 (file)
 
 #include "gmxpre.h"
 
-#include "gpu_3dfft.h"
-
-#include <cufft.h>
+#include "gpu_3dfft_cufft.h"
 
 #include "gromacs/gpu_utils/device_stream.h"
+#include "gromacs/utility/arrayref.h"
 #include "gromacs/utility/fatalerror.h"
 #include "gromacs/utility/gmxassert.h"
 
 namespace gmx
 {
-
-class Gpu3dFft::Impl
-{
-public:
-    Impl(ivec                 realGridSize,
-         ivec                 realGridSizePadded,
-         ivec                 complexGridSizePadded,
-         bool                 useDecomposition,
-         bool                 performOutOfPlaceFFT,
-         const DeviceContext& context,
-         const DeviceStream&  pmeStream,
-         DeviceBuffer<float>  realGrid,
-         DeviceBuffer<float>  complexGrid);
-    ~Impl();
-
-    cufftHandle   planR2C_;
-    cufftHandle   planC2R_;
-    cufftReal*    realGrid_;
-    cufftComplex* complexGrid_;
-};
-
 static void handleCufftError(cufftResult_t status, const char* msg)
 {
     if (status != CUFFT_SUCCESS)
@@ -82,19 +60,25 @@ static void handleCufftError(cufftResult_t status, const char* msg)
     }
 }
 
-Gpu3dFft::Impl::Impl(ivec       realGridSize,
-                     ivec       realGridSizePadded,
-                     ivec       complexGridSizePadded,
-                     const bool useDecomposition,
-                     const bool /*performOutOfPlaceFFT*/,
-                     const DeviceContext& /*context*/,
-                     const DeviceStream& pmeStream,
-                     DeviceBuffer<float> realGrid,
-                     DeviceBuffer<float> complexGrid) :
-    realGrid_(reinterpret_cast<cufftReal*>(realGrid)),
-    complexGrid_(reinterpret_cast<cufftComplex*>(complexGrid))
+Gpu3dFft::ImplCuFft::ImplCuFft(bool allocateGrids,
+                               MPI_Comm /*comm*/,
+                               ArrayRef<const int> gridSizesInXForEachRank,
+                               ArrayRef<const int> gridSizesInYForEachRank,
+                               const int /*nz*/,
+                               bool /*performOutOfPlaceFFT*/,
+                               const DeviceContext& /*context*/,
+                               const DeviceStream&  pmeStream,
+                               ivec                 realGridSize,
+                               ivec                 realGridSizePadded,
+                               ivec                 complexGridSizePadded,
+                               DeviceBuffer<float>* realGrid,
+                               DeviceBuffer<float>* complexGrid) :
+    realGrid_(reinterpret_cast<cufftReal*>(*realGrid)),
+    complexGrid_(reinterpret_cast<cufftComplex*>(*complexGrid))
 {
-    GMX_RELEASE_ASSERT(!useDecomposition, "FFT decomposition not implemented");
+    GMX_RELEASE_ASSERT(allocateGrids == false, "Grids needs to be pre-allocated");
+    GMX_RELEASE_ASSERT(gridSizesInXForEachRank.size() == 1 && gridSizesInYForEachRank.size() == 1,
+                       "FFT decomposition not implemented with cuFFT backend");
 
     const int complexGridSizePaddedTotal =
             complexGridSizePadded[XX] * complexGridSizePadded[YY] * complexGridSizePadded[ZZ];
@@ -151,7 +135,7 @@ Gpu3dFft::Impl::Impl(ivec       realGridSize,
     handleCufftError(result, "cufftSetStream C2R failure");
 }
 
-Gpu3dFft::Impl::~Impl()
+Gpu3dFft::ImplCuFft::~ImplCuFft()
 {
     cufftResult_t result;
     result = cufftDestroy(planR2C_);
@@ -160,42 +144,19 @@ Gpu3dFft::Impl::~Impl()
     handleCufftError(result, "cufftDestroy C2R failure");
 }
 
-void Gpu3dFft::perform3dFft(gmx_fft_direction dir, CommandEvent* /*timingEvent*/)
+void Gpu3dFft::ImplCuFft::perform3dFft(gmx_fft_direction dir, CommandEvent* /*timingEvent*/)
 {
     cufftResult_t result;
     if (dir == GMX_FFT_REAL_TO_COMPLEX)
     {
-        result = cufftExecR2C(impl_->planR2C_, impl_->realGrid_, impl_->complexGrid_);
+        result = cufftExecR2C(planR2C_, realGrid_, complexGrid_);
         handleCufftError(result, "cuFFT R2C execution failure");
     }
     else
     {
-        result = cufftExecC2R(impl_->planC2R_, impl_->complexGrid_, impl_->realGrid_);
+        result = cufftExecC2R(planC2R_, complexGrid_, realGrid_);
         handleCufftError(result, "cuFFT C2R execution failure");
     }
 }
 
-Gpu3dFft::Gpu3dFft(ivec                 realGridSize,
-                   ivec                 realGridSizePadded,
-                   ivec                 complexGridSizePadded,
-                   const bool           useDecomposition,
-                   const bool           performOutOfPlaceFFT,
-                   const DeviceContext& context,
-                   const DeviceStream&  pmeStream,
-                   DeviceBuffer<float>  realGrid,
-                   DeviceBuffer<float>  complexGrid) :
-    impl_(std::make_unique<Impl>(realGridSize,
-                                 realGridSizePadded,
-                                 complexGridSizePadded,
-                                 useDecomposition,
-                                 performOutOfPlaceFFT,
-                                 context,
-                                 pmeStream,
-                                 realGrid,
-                                 complexGrid))
-{
-}
-
-Gpu3dFft::~Gpu3dFft() = default;
-
 } // namespace gmx