Redesign GPU FFT abstraction
[alexxy/gromacs.git] / src / gromacs / fft / gpu_3dfft_ocl.cpp
index 69a44974459fcf9cd26c288317b4688fb2338be1..76ae53560874d7d141b7bdcaf0ccbeb8834005de 100644 (file)
@@ -43,7 +43,7 @@
 
 #include "gmxpre.h"
 
-#include "gpu_3dfft.h"
+#include "gpu_3dfft_ocl.h"
 
 #include <array>
 #include <vector>
 
 namespace gmx
 {
-
-class Gpu3dFft::Impl
-{
-public:
-    Impl(ivec                 realGridSize,
-         ivec                 realGridSizePadded,
-         ivec                 complexGridSizePadded,
-         bool                 useDecomposition,
-         bool                 performOutOfPlaceFFT,
-         const DeviceContext& context,
-         const DeviceStream&  pmeStream,
-         DeviceBuffer<float>  realGrid,
-         DeviceBuffer<float>  complexGrid);
-    ~Impl();
-
-    clfftPlanHandle               planR2C_;
-    clfftPlanHandle               planC2R_;
-    std::vector<cl_command_queue> commandStreams_;
-    cl_mem                        realGrid_;
-    cl_mem                        complexGrid_;
-};
-
 //! Throws the exception on clFFT error
 static void handleClfftError(clfftStatus status, const char* msg)
 {
@@ -91,18 +69,24 @@ static void handleClfftError(clfftStatus status, const char* msg)
     }
 }
 
-Gpu3dFft::Impl::Impl(ivec                 realGridSize,
-                     ivec                 realGridSizePadded,
-                     ivec                 complexGridSizePadded,
-                     const bool           useDecomposition,
-                     const bool           performOutOfPlaceFFT,
-                     const DeviceContext& context,
-                     const DeviceStream&  pmeStream,
-                     DeviceBuffer<float>  realGrid,
-                     DeviceBuffer<float>  complexGrid) :
-    realGrid_(realGrid), complexGrid_(complexGrid)
+Gpu3dFft::ImplOcl::ImplOcl(bool allocateGrids,
+                           MPI_Comm /*comm*/,
+                           ArrayRef<const int> gridSizesInXForEachRank,
+                           ArrayRef<const int> gridSizesInYForEachRank,
+                           const int /*nz*/,
+                           bool                 performOutOfPlaceFFT,
+                           const DeviceContext& context,
+                           const DeviceStream&  pmeStream,
+                           ivec                 realGridSize,
+                           ivec                 realGridSizePadded,
+                           ivec                 complexGridSizePadded,
+                           DeviceBuffer<float>* realGrid,
+                           DeviceBuffer<float>* complexGrid) :
+    realGrid_(*realGrid), complexGrid_(*complexGrid)
 {
-    GMX_RELEASE_ASSERT(!useDecomposition, "FFT decomposition not implemented");
+    GMX_RELEASE_ASSERT(allocateGrids == false, "Grids needs to be pre-allocated");
+    GMX_RELEASE_ASSERT(gridSizesInXForEachRank.size() == 1 && gridSizesInYForEachRank.size() == 1,
+                       "FFT decomposition not implemented with OpenCL backend");
 
     cl_context clContext = context.context();
     commandStreams_.push_back(pmeStream.stream());
@@ -157,13 +141,13 @@ Gpu3dFft::Impl::Impl(ivec                 realGridSize,
     // TODO: disable last transpose (clfftSetPlanTransposeResult)
 }
 
-Gpu3dFft::Impl::~Impl()
+Gpu3dFft::ImplOcl::~ImplOcl()
 {
     clfftDestroyPlan(&planR2C_);
     clfftDestroyPlan(&planC2R_);
 }
 
-void Gpu3dFft::perform3dFft(gmx_fft_direction dir, CommandEvent* timingEvent)
+void Gpu3dFft::ImplOcl::perform3dFft(gmx_fft_direction dir, CommandEvent* timingEvent)
 {
     cl_mem                            tempBuffer = nullptr;
     constexpr std::array<cl_event, 0> waitEvents{ {} };
@@ -175,24 +159,24 @@ void Gpu3dFft::perform3dFft(gmx_fft_direction dir, CommandEvent* timingEvent)
     switch (dir)
     {
         case GMX_FFT_REAL_TO_COMPLEX:
-            plan        = impl_->planR2C_;
+            plan        = planR2C_;
             direction   = CLFFT_FORWARD;
-            inputGrids  = &impl_->realGrid_;
-            outputGrids = &impl_->complexGrid_;
+            inputGrids  = &realGrid_;
+            outputGrids = &complexGrid_;
             break;
         case GMX_FFT_COMPLEX_TO_REAL:
-            plan        = impl_->planC2R_;
+            plan        = planC2R_;
             direction   = CLFFT_BACKWARD;
-            inputGrids  = &impl_->complexGrid_;
-            outputGrids = &impl_->realGrid_;
+            inputGrids  = &complexGrid_;
+            outputGrids = &realGrid_;
             break;
         default:
             GMX_THROW(NotImplementedError("The chosen 3D-FFT case is not implemented on GPUs"));
     }
     handleClfftError(clfftEnqueueTransform(plan,
                                            direction,
-                                           impl_->commandStreams_.size(),
-                                           impl_->commandStreams_.data(),
+                                           commandStreams_.size(),
+                                           commandStreams_.data(),
                                            waitEvents.size(),
                                            waitEvents.data(),
                                            timingEvent,
@@ -202,27 +186,4 @@ void Gpu3dFft::perform3dFft(gmx_fft_direction dir, CommandEvent* timingEvent)
                      "clFFT execution failure");
 }
 
-Gpu3dFft::Gpu3dFft(ivec                 realGridSize,
-                   ivec                 realGridSizePadded,
-                   ivec                 complexGridSizePadded,
-                   const bool           useDecomposition,
-                   const bool           performOutOfPlaceFFT,
-                   const DeviceContext& context,
-                   const DeviceStream&  pmeStream,
-                   DeviceBuffer<float>  realGrid,
-                   DeviceBuffer<float>  complexGrid) :
-    impl_(std::make_unique<Impl>(realGridSize,
-                                 realGridSizePadded,
-                                 complexGridSizePadded,
-                                 useDecomposition,
-                                 performOutOfPlaceFFT,
-                                 context,
-                                 pmeStream,
-                                 realGrid,
-                                 complexGrid))
-{
-}
-
-Gpu3dFft::~Gpu3dFft() = default;
-
 } // namespace gmx