#include "pme_gpu_3dfft.h"
+#include <cufft.h>
+
+#include "gromacs/gpu_utils/device_stream.h"
#include "gromacs/utility/fatalerror.h"
#include "gromacs/utility/gmxassert.h"
-#include "pme.cuh"
-#include "pme_gpu_types.h"
-#include "pme_gpu_types_host.h"
-#include "pme_gpu_types_host_impl.h"
+class GpuParallel3dFft::Impl // Implementation class holding the cuFFT plan handles and the grid pointers used by GpuParallel3dFft
+{
+public:
+ Impl(ivec realGridSize, // Takes the same argument list, in the same order, as the GpuParallel3dFft constructor
+ ivec realGridSizePadded,
+ ivec complexGridSizePadded,
+ bool useDecomposition, // The definition asserts that this is false
+ bool performOutOfPlaceFFT, // Left unnamed (unused) in the definition in this file
+ const DeviceContext& context, // Left unnamed (unused) in the definition in this file
+ const DeviceStream& pmeStream, // Source of the CUDA stream handed to cufftSetStream
+ DeviceBuffer<float> realGrid, // Reinterpreted as cufftReal* and stored in realGrid_
+ DeviceBuffer<float> complexGrid); // Reinterpreted as cufftComplex* and stored in complexGrid_
+ ~Impl(); // Calls cufftDestroy on the plan handle(s)
+
+ cufftHandle planR2C_; //!< cuFFT plan handle passed to cufftExecR2C
+ cufftHandle planC2R_; //!< cuFFT plan handle passed to cufftExecC2R
+ cufftReal* realGrid_; //!< Real-space grid: input of the R2C transform, output of the C2R transform
+ cufftComplex* complexGrid_; //!< Complex grid: output of the R2C transform, input of the C2R transform
+};
static void handleCufftError(cufftResult_t status, const char* msg)
{
}
}
-GpuParallel3dFft::GpuParallel3dFft(const PmeGpu* pmeGpu, const int gridIndex)
+GpuParallel3dFft::Impl::Impl(ivec realGridSize, // Stores the grid pointers, validates the inputs and attaches the cuFFT plan(s) to the PME stream
+ ivec realGridSizePadded, // Padded real grid extents; their product gives realGridSizePaddedTotal
+ ivec complexGridSizePadded, // Padded complex grid extents; their product gives complexGridSizePaddedTotal
+ const bool useDecomposition, // Must be false, see the release assert below
+ const bool /*performOutOfPlaceFFT*/, // Parameter name commented out: not used by this implementation
+ const DeviceContext& /*context*/, // Parameter name commented out: not used by this implementation
+ const DeviceStream& pmeStream, // Provides the cudaStream_t passed to cufftSetStream
+ DeviceBuffer<float> realGrid, // Stored as realGrid_; asserted non-null below
+ DeviceBuffer<float> complexGrid) : // Stored as complexGrid_; asserted non-null below
+ realGrid_(reinterpret_cast<cufftReal*>(realGrid)), // View the float device buffer through the cuFFT real type
+ complexGrid_(reinterpret_cast<cufftComplex*>(complexGrid)) // View the float device buffer through the cuFFT complex type
{
- const PmeGpuCudaKernelParams* kernelParamsPtr = pmeGpu->kernelParams.get();
- ivec realGridSize, realGridSizePadded, complexGridSizePadded;
- for (int i = 0; i < DIM; i++)
- {
- realGridSize[i] = kernelParamsPtr->grid.realGridSize[i];
- realGridSizePadded[i] = kernelParamsPtr->grid.realGridSizePadded[i];
- complexGridSizePadded[i] = kernelParamsPtr->grid.complexGridSizePadded[i];
- }
-
- GMX_RELEASE_ASSERT(!pme_gpu_settings(pmeGpu).useDecomposition,
- "FFT decomposition not implemented");
+ GMX_RELEASE_ASSERT(!useDecomposition, "FFT decomposition not implemented"); // The flag now arrives as a plain argument rather than via PmeGpu settings
const int complexGridSizePaddedTotal =
complexGridSizePadded[XX] * complexGridSizePadded[YY] * complexGridSizePadded[ZZ];
const int realGridSizePaddedTotal =
realGridSizePadded[XX] * realGridSizePadded[YY] * realGridSizePadded[ZZ];
- realGrid_ = reinterpret_cast<cufftReal*>(kernelParamsPtr->grid.d_realGrid[gridIndex]);
GMX_RELEASE_ASSERT(realGrid_, "Bad (null) input real-space grid");
- complexGrid_ = reinterpret_cast<cufftComplex*>(kernelParamsPtr->grid.d_fourierGrid[gridIndex]);
GMX_RELEASE_ASSERT(complexGrid_, "Bad (null) input complex grid");
cufftResult_t result;
batch);
handleCufftError(result, "cufftPlanMany C2R plan failure");
- cudaStream_t stream = pmeGpu->archSpecific->pmeStream_.stream();
- GMX_RELEASE_ASSERT(stream, "Using the default CUDA stream for PME cuFFT");
+ cudaStream_t stream = pmeStream.stream(); // Raw CUDA stream handle taken from the DeviceStream argument
+ GMX_RELEASE_ASSERT(stream, "Can not use the default CUDA stream for PME cuFFT"); // A null handle is rejected; per the message it would mean the default stream
result = cufftSetStream(planR2C_, stream);
handleCufftError(result, "cufftSetStream R2C failure");
handleCufftError(result, "cufftSetStream C2R failure");
}
-GpuParallel3dFft::~GpuParallel3dFft()
+GpuParallel3dFft::Impl::~Impl()
{
cufftResult_t result;
result = cufftDestroy(planR2C_);
cufftResult_t result;
if (dir == GMX_FFT_REAL_TO_COMPLEX)
{
- result = cufftExecR2C(planR2C_, realGrid_, complexGrid_);
+ result = cufftExecR2C(impl_->planR2C_, impl_->realGrid_, impl_->complexGrid_);
handleCufftError(result, "cuFFT R2C execution failure");
}
else
{
- result = cufftExecC2R(planC2R_, complexGrid_, realGrid_);
+ result = cufftExecC2R(impl_->planC2R_, impl_->complexGrid_, impl_->realGrid_);
handleCufftError(result, "cuFFT C2R execution failure");
}
}
+
+GpuParallel3dFft::GpuParallel3dFft(ivec realGridSize, // Public constructor: does no work itself, it only builds the Impl object from its arguments
+ ivec realGridSizePadded,
+ ivec complexGridSizePadded,
+ const bool useDecomposition,
+ const bool performOutOfPlaceFFT,
+ const DeviceContext& context,
+ const DeviceStream& pmeStream,
+ DeviceBuffer<float> realGrid,
+ DeviceBuffer<float> complexGrid) :
+ impl_(std::make_unique<Impl>(realGridSize, // Every argument is forwarded unchanged and in the same order
+ realGridSizePadded,
+ complexGridSizePadded,
+ useDecomposition,
+ performOutOfPlaceFFT,
+ context,
+ pmeStream,
+ realGrid,
+ complexGrid))
+{
+}
+
+GpuParallel3dFft::~GpuParallel3dFft() = default; // Defaulted in this file, where Impl is a complete type (presumably impl_ is a std::unique_ptr<Impl>; confirm in the header)