Redesign GPU FFT abstraction
[alexxy/gromacs.git] / src / gromacs / fft / gpu_3dfft.h
index 65d3f6f03bd18241005cdf1dfb9220aeef2f20d7..7b2c6376546336ff4c79f21ca3ddfed08a686268 100644 (file)
@@ -38,6 +38,7 @@
  *
  *  \author Aleksei Iupinov <a.yupinov@gmail.com>
  *  \author Mark Abraham <mark.j.abraham@gmail.com>
+ *  \author Gaurav Garg <gaugarg@nvidia.com>
  *  \ingroup module_fft
  */
 
@@ -49,6 +50,7 @@
 #include "gromacs/fft/fft.h"
 #include "gromacs/gpu_utils/devicebuffer_datatype.h"
 #include "gromacs/gpu_utils/gputraits.h"
+#include "gromacs/utility/gmxmpi.h"
 
 class DeviceContext;
 class DeviceStream;
@@ -56,35 +58,59 @@ class DeviceStream;
 namespace gmx
 {
 
+template<typename T>
+class ArrayRef;
+
+/*! \internal \brief
+ * Enum specifying all GPU FFT backends supported by GROMACS
+ * Some of the backends support only single GPU, some only multi-node, multi-GPU
+ */
+enum class FftBackend
+{
+    Cufft, // supports only single-GPU
+    Ocl,   // supports only single-GPU
+    Sycl,  // Not supported currently
+    Count
+};
+
 /*! \internal \brief
  * A 3D FFT class for performing R2C/C2R transforms
- * \todo Make this class actually parallel over multiple GPUs
  */
 class Gpu3dFft
 {
 public:
     /*! \brief
-     * Constructs GPU FFT plans for performing 3D FFT on a PME grid.
+     * Construct 3D FFT object for given backend
      *
-     * \param[in]  realGridSize           Dimensions of the real grid
-     * \param[in]  realGridSizePadded     Dimensions of the real grid with padding
-     * \param[in]  complexGridSizePadded  Dimensions of the real grid with padding
-     * \param[in]  useDecomposition       Whether PME decomposition will be used
-     * \param[in]  performOutOfPlaceFFT   Whether the FFT will be performed out-of-place
-     * \param[in]  context                GPU context.
-     * \param[in]  pmeStream              GPU stream for PME.
-     * \param[in]  realGrid               Device buffer of floats for the real grid
-     * \param[in]  complexGrid            Device buffer of complex floats for the complex grid
+     * \param[in]  backend                      FFT backend to be instantiated
+     * \param[in]  allocateGrids                True if fft grids are to be allocated, false if pre-allocated
+     * \param[in]  comm                         MPI communicator, used with distributed-FFT backends
+     * \param[in]  gridSizesInXForEachRank      Number of grid points used with each rank in X-dimension
+     * \param[in]  gridSizesInYForEachRank      Number of grid points used with each rank in Y-dimension
+     * \param[in]  nz                           Grid dimension in Z
+     * \param[in]  performOutOfPlaceFFT         Whether the FFT will be performed out-of-place
+     * \param[in]  context                      GPU context.
+     * \param[in]  pmeStream                    GPU stream for PME.
+     * \param[in,out]  realGridSize             Dimensions of the local real grid, out if allocateGrids=true
+     * \param[in,out]  realGridSizePadded       Dimensions of the local real grid with padding, out if allocateGrids=true
+     * \param[in,out]  complexGridSizePadded    Dimensions of the local complex grid with padding, out if allocateGrids=true
+     * \param[in,out]  realGrid                 Device buffer of floats for the local real grid, out if allocateGrids=true
+     * \param[in,out]  complexGrid              Device buffer of complex floats for the local complex grid, out if allocateGrids=true
      */
-    Gpu3dFft(ivec                 realGridSize,
-             ivec                 realGridSizePadded,
-             ivec                 complexGridSizePadded,
-             bool                 useDecomposition,
+    Gpu3dFft(FftBackend           backend,
+             bool                 allocateGrids,
+             MPI_Comm             comm,
+             ArrayRef<const int>  gridSizesInXForEachRank,
+             ArrayRef<const int>  gridSizesInYForEachRank,
+             int                  nz,
              bool                 performOutOfPlaceFFT,
              const DeviceContext& context,
              const DeviceStream&  pmeStream,
-             DeviceBuffer<float>  realGrid,
-             DeviceBuffer<float>  complexGrid);
+             ivec                 realGridSize,
+             ivec                 realGridSizePadded,
+             ivec                 complexGridSizePadded,
+             DeviceBuffer<float>* realGrid,
+             DeviceBuffer<float>* complexGrid);
 
     /*! \brief Destroys the FFT plans. */
     ~Gpu3dFft();
@@ -97,6 +123,10 @@ public:
 
 private:
     class Impl;
+    class ImplCuFft;
+    class ImplOcl;
+    class ImplSycl;
+
     std::unique_ptr<Impl> impl_;
 };