*
* \author Aleksei Iupinov <a.yupinov@gmail.com>
* \author Mark Abraham <mark.j.abraham@gmail.com>
+ * \author Gaurav Garg <gaugarg@nvidia.com>
* \ingroup module_fft
*/
#include "gromacs/fft/fft.h"
#include "gromacs/gpu_utils/devicebuffer_datatype.h"
#include "gromacs/gpu_utils/gputraits.h"
+#include "gromacs/utility/gmxmpi.h"
class DeviceContext;
class DeviceStream;
namespace gmx
{
+template<typename T>
+class ArrayRef;
+
+/*! \internal \brief
+ * Enumerates the GPU FFT backends known to GROMACS.
+ * Some backends support only a single GPU, some only multi-node, multi-GPU runs.
+ */
+enum class FftBackend
+{
+ Cufft, // Supports only single-GPU
+ Ocl, // Supports only single-GPU
+ Sycl, // Not supported currently
+ Count // Number of backends listed above; presumably not a selectable backend itself
+};
+
/*! \internal \brief
* A 3D FFT class for performing R2C/C2R transforms
- * \todo Make this class actually parallel over multiple GPUs
*/
class Gpu3dFft
{
public:
/*! \brief
- * Constructs GPU FFT plans for performing 3D FFT on a PME grid.
+ * Constructs a 3D FFT object for the given backend.
*
- * \param[in] realGridSize Dimensions of the real grid
- * \param[in] realGridSizePadded Dimensions of the real grid with padding
- * \param[in] complexGridSizePadded Dimensions of the real grid with padding
- * \param[in] useDecomposition Whether PME decomposition will be used
- * \param[in] performOutOfPlaceFFT Whether the FFT will be performed out-of-place
- * \param[in] context GPU context.
- * \param[in] pmeStream GPU stream for PME.
- * \param[in] realGrid Device buffer of floats for the real grid
- * \param[in] complexGrid Device buffer of complex floats for the complex grid
+ * \param[in] backend FFT backend to be instantiated
+ * \param[in] allocateGrids True if the FFT grids are to be allocated, false if they are pre-allocated
+ * \param[in] comm MPI communicator, used with distributed-FFT backends
+ * \param[in] gridSizesInXForEachRank Number of grid points used by each rank in the X dimension
+ * \param[in] gridSizesInYForEachRank Number of grid points used by each rank in the Y dimension
+ * \param[in] nz Grid dimension in Z
+ * \param[in] performOutOfPlaceFFT Whether the FFT will be performed out-of-place
+ * \param[in] context GPU context.
+ * \param[in] pmeStream GPU stream for PME.
+ * \param[in,out] realGridSize Dimensions of the local real grid; output if allocateGrids=true
+ * \param[in,out] realGridSizePadded Dimensions of the local real grid with padding; output if allocateGrids=true
+ * \param[in,out] complexGridSizePadded Dimensions of the local complex grid with padding; output if allocateGrids=true
+ * \param[in,out] realGrid Pointer to the device buffer of floats for the local real grid; output if allocateGrids=true
+ * \param[in,out] complexGrid Pointer to the device buffer of complex floats for the local complex grid; output if allocateGrids=true
*/
- Gpu3dFft(ivec realGridSize,
- ivec realGridSizePadded,
- ivec complexGridSizePadded,
- bool useDecomposition,
+ Gpu3dFft(FftBackend backend,
+ bool allocateGrids,
+ MPI_Comm comm,
+ ArrayRef<const int> gridSizesInXForEachRank,
+ ArrayRef<const int> gridSizesInYForEachRank,
+ int nz,
bool performOutOfPlaceFFT,
const DeviceContext& context,
const DeviceStream& pmeStream,
- DeviceBuffer<float> realGrid,
- DeviceBuffer<float> complexGrid);
+ ivec realGridSize,
+ ivec realGridSizePadded,
+ ivec complexGridSizePadded,
+ DeviceBuffer<float>* realGrid,
+ DeviceBuffer<float>* complexGrid);
/*! \brief Destroys the FFT plans. */
~Gpu3dFft();
private:
class Impl;
+ class ImplCuFft; // Declared here, defined elsewhere; presumably the FftBackend::Cufft implementation
+ class ImplOcl; // Declared here, defined elsewhere; presumably the FftBackend::Ocl implementation
+ class ImplSycl; // Declared here, defined elsewhere; presumably the FftBackend::Sycl implementation
+
std::unique_ptr<Impl> impl_;
};