Decouple PME GPU 3DFFT from PME GPU module
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gpu_3dfft.cu
index 80daa420202c9cee1f7f8abd20807ccae3e873c3..f547fc6bcc6289793e9e4ac1f74d2c2f7f293d4f 100644 (file)
 
 #include "pme_gpu_3dfft.h"
 
+#include <cufft.h>
+
+#include "gromacs/gpu_utils/device_stream.h"
 #include "gromacs/utility/fatalerror.h"
 #include "gromacs/utility/gmxassert.h"
 
-#include "pme.cuh"
-#include "pme_gpu_types.h"
-#include "pme_gpu_types_host.h"
-#include "pme_gpu_types_host_impl.h"
+class GpuParallel3dFft::Impl
+{
+public:
+    Impl(ivec                 realGridSize,
+         ivec                 realGridSizePadded,
+         ivec                 complexGridSizePadded,
+         bool                 useDecomposition,
+         bool                 performOutOfPlaceFFT,
+         const DeviceContext& context,
+         const DeviceStream&  pmeStream,
+         DeviceBuffer<float>  realGrid,
+         DeviceBuffer<float>  complexGrid);
+    ~Impl();
+
+    cufftHandle   planR2C_;
+    cufftHandle   planC2R_;
+    cufftReal*    realGrid_;
+    cufftComplex* complexGrid_;
+};
 
 static void handleCufftError(cufftResult_t status, const char* msg)
 {
@@ -60,28 +78,26 @@ static void handleCufftError(cufftResult_t status, const char* msg)
     }
 }
 
-GpuParallel3dFft::GpuParallel3dFft(const PmeGpu* pmeGpu, const int gridIndex)
+GpuParallel3dFft::Impl::Impl(ivec       realGridSize,
+                             ivec       realGridSizePadded,
+                             ivec       complexGridSizePadded,
+                             const bool useDecomposition,
+                             const bool /*performOutOfPlaceFFT*/,
+                             const DeviceContext& /*context*/,
+                             const DeviceStream& pmeStream,
+                             DeviceBuffer<float> realGrid,
+                             DeviceBuffer<float> complexGrid) :
+    realGrid_(reinterpret_cast<cufftReal*>(realGrid)),
+    complexGrid_(reinterpret_cast<cufftComplex*>(complexGrid))
 {
-    const PmeGpuCudaKernelParams* kernelParamsPtr = pmeGpu->kernelParams.get();
-    ivec                          realGridSize, realGridSizePadded, complexGridSizePadded;
-    for (int i = 0; i < DIM; i++)
-    {
-        realGridSize[i]          = kernelParamsPtr->grid.realGridSize[i];
-        realGridSizePadded[i]    = kernelParamsPtr->grid.realGridSizePadded[i];
-        complexGridSizePadded[i] = kernelParamsPtr->grid.complexGridSizePadded[i];
-    }
-
-    GMX_RELEASE_ASSERT(!pme_gpu_settings(pmeGpu).useDecomposition,
-                       "FFT decomposition not implemented");
+    GMX_RELEASE_ASSERT(!useDecomposition, "FFT decomposition not implemented");
 
     const int complexGridSizePaddedTotal =
             complexGridSizePadded[XX] * complexGridSizePadded[YY] * complexGridSizePadded[ZZ];
     const int realGridSizePaddedTotal =
             realGridSizePadded[XX] * realGridSizePadded[YY] * realGridSizePadded[ZZ];
 
-    realGrid_ = reinterpret_cast<cufftReal*>(kernelParamsPtr->grid.d_realGrid[gridIndex]);
     GMX_RELEASE_ASSERT(realGrid_, "Bad (null) input real-space grid");
-    complexGrid_ = reinterpret_cast<cufftComplex*>(kernelParamsPtr->grid.d_fourierGrid[gridIndex]);
     GMX_RELEASE_ASSERT(complexGrid_, "Bad (null) input complex grid");
 
     cufftResult_t result;
@@ -121,8 +137,8 @@ GpuParallel3dFft::GpuParallel3dFft(const PmeGpu* pmeGpu, const int gridIndex)
                            batch);
     handleCufftError(result, "cufftPlanMany C2R plan failure");
 
-    cudaStream_t stream = pmeGpu->archSpecific->pmeStream_.stream();
-    GMX_RELEASE_ASSERT(stream, "Using the default CUDA stream for PME cuFFT");
+    cudaStream_t stream = pmeStream.stream();
+    GMX_RELEASE_ASSERT(stream, "Can not use the default CUDA stream for PME cuFFT");
 
     result = cufftSetStream(planR2C_, stream);
     handleCufftError(result, "cufftSetStream R2C failure");
@@ -131,7 +147,7 @@ GpuParallel3dFft::GpuParallel3dFft(const PmeGpu* pmeGpu, const int gridIndex)
     handleCufftError(result, "cufftSetStream C2R failure");
 }
 
-GpuParallel3dFft::~GpuParallel3dFft()
+GpuParallel3dFft::Impl::~Impl()
 {
     cufftResult_t result;
     result = cufftDestroy(planR2C_);
@@ -145,12 +161,35 @@ void GpuParallel3dFft::perform3dFft(gmx_fft_direction dir, CommandEvent* /*timin
     cufftResult_t result;
     if (dir == GMX_FFT_REAL_TO_COMPLEX)
     {
-        result = cufftExecR2C(planR2C_, realGrid_, complexGrid_);
+        result = cufftExecR2C(impl_->planR2C_, impl_->realGrid_, impl_->complexGrid_);
         handleCufftError(result, "cuFFT R2C execution failure");
     }
     else
     {
-        result = cufftExecC2R(planC2R_, complexGrid_, realGrid_);
+        result = cufftExecC2R(impl_->planC2R_, impl_->complexGrid_, impl_->realGrid_);
         handleCufftError(result, "cuFFT C2R execution failure");
     }
 }
+
+GpuParallel3dFft::GpuParallel3dFft(ivec                 realGridSize,
+                                   ivec                 realGridSizePadded,
+                                   ivec                 complexGridSizePadded,
+                                   const bool           useDecomposition,
+                                   const bool           performOutOfPlaceFFT,
+                                   const DeviceContext& context,
+                                   const DeviceStream&  pmeStream,
+                                   DeviceBuffer<float>  realGrid,
+                                   DeviceBuffer<float>  complexGrid) :
+    impl_(std::make_unique<Impl>(realGridSize,
+                                 realGridSizePadded,
+                                 complexGridSizePadded,
+                                 useDecomposition,
+                                 performOutOfPlaceFFT,
+                                 context,
+                                 pmeStream,
+                                 realGrid,
+                                 complexGrid))
+{
+}
+
+GpuParallel3dFft::~GpuParallel3dFft() = default;