#ifndef GMX_EWALD_PME_GPU_3DFFT_H
#define GMX_EWALD_PME_GPU_3DFFT_H
-#include "config.h"
+#include <memory>
-#include <vector>
-
-#if GMX_GPU_CUDA
-# include <cufft.h>
-
-# include "gromacs/gpu_utils/gputraits.cuh"
-#elif GMX_GPU_OPENCL
-# include <clFFT.h>
-
-# include "gromacs/gpu_utils/gmxopencl.h"
-# include "gromacs/gpu_utils/gputraits_ocl.h"
-#elif GMX_GPU_SYCL
-# include "gromacs/gpu_utils/gputraits_sycl.h"
-#endif
-
-#include "gromacs/fft/fft.h" // for the enum gmx_fft_direction
+#include "gromacs/fft/fft.h"
+#include "gromacs/gpu_utils/devicebuffer_datatype.h"
+#include "gromacs/gpu_utils/gputraits.h"
+class DeviceContext;
+class DeviceStream;
struct PmeGpu;
/*! \internal \brief
{
public:
/*! \brief
- * Constructs CUDA/OpenCL FFT plans for performing 3D FFT on a PME grid.
+ * Constructs GPU FFT plans for performing 3D FFT on a PME grid.
*
- * \param[in] pmeGpu The PME GPU structure.
- * \param[in] gridIndex The index of the grid on which to perform the calculations.
+ * \param[in] realGridSize Dimensions of the real grid
+ * \param[in] realGridSizePadded Dimensions of the real grid with padding
+ * \param[in] complexGridSizePadded Dimensions of the real grid with padding
+ * \param[in] useDecomposition Whether PME decomposition will be used
+ * \param[in] performOutOfPlaceFFT Whether the FFT will be performed out-of-place
+ * \param[in] context GPU context.
+ * \param[in] pmeStream GPU stream for PME.
+ * \param[in] realGrid Device buffer of floats for the real grid
+ * \param[in] complexGrid Device buffer of complex floats for the complex grid
*/
- GpuParallel3dFft(const PmeGpu* pmeGpu, int gridIndex);
+ GpuParallel3dFft(ivec realGridSize,
+ ivec realGridSizePadded,
+ ivec complexGridSizePadded,
+ bool useDecomposition,
+ bool performOutOfPlaceFFT,
+ const DeviceContext& context,
+ const DeviceStream& pmeStream,
+ DeviceBuffer<float> realGrid,
+ DeviceBuffer<float> complexGrid);
+
/*! \brief Destroys the FFT plans. */
~GpuParallel3dFft();
/*! \brief Performs the FFT transform in given direction
void perform3dFft(gmx_fft_direction dir, CommandEvent* timingEvent);
private:
-#if GMX_GPU_CUDA
- cufftHandle planR2C_;
- cufftHandle planC2R_;
- cufftReal* realGrid_;
- cufftComplex* complexGrid_;
-#elif GMX_GPU_OPENCL
- clfftPlanHandle planR2C_;
- clfftPlanHandle planC2R_;
- std::vector<cl_command_queue> deviceStreams_;
- cl_mem realGrid_;
- cl_mem complexGrid_;
-#endif
+ class Impl;
+ std::unique_ptr<Impl> impl_;
};
#endif