Decouple PME GPU 3DFFT from PME GPU module
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gpu_3dfft.h
index d71e43522ca59c73849d8c677c6e93393c186aa6..5939dea0c685258ef177150aa02ed433edc5dc13 100644 (file)
 #ifndef GMX_EWALD_PME_GPU_3DFFT_H
 #define GMX_EWALD_PME_GPU_3DFFT_H
 
-#include "config.h"
+#include <memory>
 
-#include <vector>
-
-#if GMX_GPU_CUDA
-#    include <cufft.h>
-
-#    include "gromacs/gpu_utils/gputraits.cuh"
-#elif GMX_GPU_OPENCL
-#    include <clFFT.h>
-
-#    include "gromacs/gpu_utils/gmxopencl.h"
-#    include "gromacs/gpu_utils/gputraits_ocl.h"
-#elif GMX_GPU_SYCL
-#    include "gromacs/gpu_utils/gputraits_sycl.h"
-#endif
-
-#include "gromacs/fft/fft.h" // for the enum gmx_fft_direction
+#include "gromacs/fft/fft.h"
+#include "gromacs/gpu_utils/devicebuffer_datatype.h"
+#include "gromacs/gpu_utils/gputraits.h"
 
+class DeviceContext;
+class DeviceStream;
 struct PmeGpu;
 
 /*! \internal \brief
@@ -73,12 +62,28 @@ class GpuParallel3dFft
 {
 public:
     /*! \brief
-     * Constructs CUDA/OpenCL FFT plans for performing 3D FFT on a PME grid.
+     * Constructs GPU FFT plans for performing 3D FFT on a PME grid.
      *
-     * \param[in] pmeGpu                  The PME GPU structure.
-     * \param[in] gridIndex               The index of the grid on which to perform the calculations.
+     * \param[in]  realGridSize           Dimensions of the real grid
+     * \param[in]  realGridSizePadded     Dimensions of the real grid with padding
+     * \param[in]  complexGridSizePadded  Dimensions of the complex grid with padding
+     * \param[in]  useDecomposition       Whether PME decomposition will be used
+     * \param[in]  performOutOfPlaceFFT   Whether the FFT will be performed out-of-place
+     * \param[in]  context                GPU context.
+     * \param[in]  pmeStream              GPU stream for PME.
+     * \param[in]  realGrid               Device buffer of floats for the real grid
+     * \param[in]  complexGrid            Device buffer of complex floats for the complex grid
      */
-    GpuParallel3dFft(const PmeGpu* pmeGpu, int gridIndex);
+    GpuParallel3dFft(ivec                 realGridSize,
+                     ivec                 realGridSizePadded,
+                     ivec                 complexGridSizePadded,
+                     bool                 useDecomposition,
+                     bool                 performOutOfPlaceFFT,
+                     const DeviceContext& context,
+                     const DeviceStream&  pmeStream,
+                     DeviceBuffer<float>  realGrid,
+                     DeviceBuffer<float>  complexGrid);
+
     /*! \brief Destroys the FFT plans. */
     ~GpuParallel3dFft();
     /*! \brief Performs the FFT transform in given direction
@@ -89,18 +94,8 @@ public:
     void perform3dFft(gmx_fft_direction dir, CommandEvent* timingEvent);
 
 private:
-#if GMX_GPU_CUDA
-    cufftHandle   planR2C_;
-    cufftHandle   planC2R_;
-    cufftReal*    realGrid_;
-    cufftComplex* complexGrid_;
-#elif GMX_GPU_OPENCL
-    clfftPlanHandle               planR2C_;
-    clfftPlanHandle               planC2R_;
-    std::vector<cl_command_queue> deviceStreams_;
-    cl_mem                        realGrid_;
-    cl_mem                        complexGrid_;
-#endif
+    class Impl;
+    std::unique_ptr<Impl> impl_;
 };
 
 #endif