Decouple PME GPU 3DFFT from PME GPU module
author Mark Abraham <mark.j.abraham@gmail.com>
Sat, 26 Jun 2021 16:48:46 +0000 (16:48 +0000)
committer Artem Zhmurov <zhmurov@gmail.com>
Sat, 26 Jun 2021 16:48:46 +0000 (16:48 +0000)
src/gromacs/ewald/CMakeLists.txt
src/gromacs/ewald/pme_gpu_3dfft.cu
src/gromacs/ewald/pme_gpu_3dfft.h
src/gromacs/ewald/pme_gpu_3dfft_ocl.cpp
src/gromacs/ewald/pme_gpu_3dfft_sycl.cpp [new file with mode: 0644]
src/gromacs/ewald/pme_gpu_internal.cpp
src/gromacs/ewald/pme_gpu_sycl_stubs.cpp

index ac50cc1a4811948787c6e9366f7b71efa2712e9c..f9041d8d2b17c15bdb560835e4740bf6d577dc84 100644 (file)
@@ -92,6 +92,7 @@ elseif (GMX_GPU_SYCL)
     gmx_add_libgromacs_sources(
         # Files that implement stubs
         pme_gpu_sycl_stubs.cpp
+        pme_gpu_3dfft_sycl.cpp
         # GPU-specific sources
         pme_gpu.cpp
         pme_gpu_internal.cpp
@@ -101,6 +102,7 @@ elseif (GMX_GPU_SYCL)
         pme_gpu_internal.cpp
         pme_gpu_program.cpp
         pme_gpu_sycl_stubs.cpp
+        pme_gpu_3dfft_sycl.cpp
         pme_gpu_timings.cpp
         )
 else()
index 80daa420202c9cee1f7f8abd20807ccae3e873c3..f547fc6bcc6289793e9e4ac1f74d2c2f7f293d4f 100644 (file)
 
 #include "pme_gpu_3dfft.h"
 
+#include <cufft.h>
+
+#include "gromacs/gpu_utils/device_stream.h"
 #include "gromacs/utility/fatalerror.h"
 #include "gromacs/utility/gmxassert.h"
 
-#include "pme.cuh"
-#include "pme_gpu_types.h"
-#include "pme_gpu_types_host.h"
-#include "pme_gpu_types_host_impl.h"
+class GpuParallel3dFft::Impl
+{
+public:
+    Impl(ivec                 realGridSize,
+         ivec                 realGridSizePadded,
+         ivec                 complexGridSizePadded,
+         bool                 useDecomposition,
+         bool                 performOutOfPlaceFFT,
+         const DeviceContext& context,
+         const DeviceStream&  pmeStream,
+         DeviceBuffer<float>  realGrid,
+         DeviceBuffer<float>  complexGrid);
+    ~Impl();
+
+    cufftHandle   planR2C_;
+    cufftHandle   planC2R_;
+    cufftReal*    realGrid_;
+    cufftComplex* complexGrid_;
+};
 
 static void handleCufftError(cufftResult_t status, const char* msg)
 {
@@ -60,28 +78,26 @@ static void handleCufftError(cufftResult_t status, const char* msg)
     }
 }
 
-GpuParallel3dFft::GpuParallel3dFft(const PmeGpu* pmeGpu, const int gridIndex)
+GpuParallel3dFft::Impl::Impl(ivec       realGridSize,
+                             ivec       realGridSizePadded,
+                             ivec       complexGridSizePadded,
+                             const bool useDecomposition,
+                             const bool /*performOutOfPlaceFFT*/,
+                             const DeviceContext& /*context*/,
+                             const DeviceStream& pmeStream,
+                             DeviceBuffer<float> realGrid,
+                             DeviceBuffer<float> complexGrid) :
+    realGrid_(reinterpret_cast<cufftReal*>(realGrid)),
+    complexGrid_(reinterpret_cast<cufftComplex*>(complexGrid))
 {
-    const PmeGpuCudaKernelParams* kernelParamsPtr = pmeGpu->kernelParams.get();
-    ivec                          realGridSize, realGridSizePadded, complexGridSizePadded;
-    for (int i = 0; i < DIM; i++)
-    {
-        realGridSize[i]          = kernelParamsPtr->grid.realGridSize[i];
-        realGridSizePadded[i]    = kernelParamsPtr->grid.realGridSizePadded[i];
-        complexGridSizePadded[i] = kernelParamsPtr->grid.complexGridSizePadded[i];
-    }
-
-    GMX_RELEASE_ASSERT(!pme_gpu_settings(pmeGpu).useDecomposition,
-                       "FFT decomposition not implemented");
+    GMX_RELEASE_ASSERT(!useDecomposition, "FFT decomposition not implemented");
 
     const int complexGridSizePaddedTotal =
             complexGridSizePadded[XX] * complexGridSizePadded[YY] * complexGridSizePadded[ZZ];
     const int realGridSizePaddedTotal =
             realGridSizePadded[XX] * realGridSizePadded[YY] * realGridSizePadded[ZZ];
 
-    realGrid_ = reinterpret_cast<cufftReal*>(kernelParamsPtr->grid.d_realGrid[gridIndex]);
     GMX_RELEASE_ASSERT(realGrid_, "Bad (null) input real-space grid");
-    complexGrid_ = reinterpret_cast<cufftComplex*>(kernelParamsPtr->grid.d_fourierGrid[gridIndex]);
     GMX_RELEASE_ASSERT(complexGrid_, "Bad (null) input complex grid");
 
     cufftResult_t result;
@@ -121,8 +137,8 @@ GpuParallel3dFft::GpuParallel3dFft(const PmeGpu* pmeGpu, const int gridIndex)
                            batch);
     handleCufftError(result, "cufftPlanMany C2R plan failure");
 
-    cudaStream_t stream = pmeGpu->archSpecific->pmeStream_.stream();
-    GMX_RELEASE_ASSERT(stream, "Using the default CUDA stream for PME cuFFT");
+    cudaStream_t stream = pmeStream.stream();
+    GMX_RELEASE_ASSERT(stream, "Can not use the default CUDA stream for PME cuFFT");
 
     result = cufftSetStream(planR2C_, stream);
     handleCufftError(result, "cufftSetStream R2C failure");
@@ -131,7 +147,7 @@ GpuParallel3dFft::GpuParallel3dFft(const PmeGpu* pmeGpu, const int gridIndex)
     handleCufftError(result, "cufftSetStream C2R failure");
 }
 
-GpuParallel3dFft::~GpuParallel3dFft()
+GpuParallel3dFft::Impl::~Impl()
 {
     cufftResult_t result;
     result = cufftDestroy(planR2C_);
@@ -145,12 +161,35 @@ void GpuParallel3dFft::perform3dFft(gmx_fft_direction dir, CommandEvent* /*timin
     cufftResult_t result;
     if (dir == GMX_FFT_REAL_TO_COMPLEX)
     {
-        result = cufftExecR2C(planR2C_, realGrid_, complexGrid_);
+        result = cufftExecR2C(impl_->planR2C_, impl_->realGrid_, impl_->complexGrid_);
         handleCufftError(result, "cuFFT R2C execution failure");
     }
     else
     {
-        result = cufftExecC2R(planC2R_, complexGrid_, realGrid_);
+        result = cufftExecC2R(impl_->planC2R_, impl_->complexGrid_, impl_->realGrid_);
         handleCufftError(result, "cuFFT C2R execution failure");
     }
 }
+
+GpuParallel3dFft::GpuParallel3dFft(ivec                 realGridSize,
+                                   ivec                 realGridSizePadded,
+                                   ivec                 complexGridSizePadded,
+                                   const bool           useDecomposition,
+                                   const bool           performOutOfPlaceFFT,
+                                   const DeviceContext& context,
+                                   const DeviceStream&  pmeStream,
+                                   DeviceBuffer<float>  realGrid,
+                                   DeviceBuffer<float>  complexGrid) :
+    impl_(std::make_unique<Impl>(realGridSize,
+                                 realGridSizePadded,
+                                 complexGridSizePadded,
+                                 useDecomposition,
+                                 performOutOfPlaceFFT,
+                                 context,
+                                 pmeStream,
+                                 realGrid,
+                                 complexGrid))
+{
+}
+
+GpuParallel3dFft::~GpuParallel3dFft() = default;
index d71e43522ca59c73849d8c677c6e93393c186aa6..5939dea0c685258ef177150aa02ed433edc5dc13 100644 (file)
 #ifndef GMX_EWALD_PME_GPU_3DFFT_H
 #define GMX_EWALD_PME_GPU_3DFFT_H
 
-#include "config.h"
+#include <memory>
 
-#include <vector>
-
-#if GMX_GPU_CUDA
-#    include <cufft.h>
-
-#    include "gromacs/gpu_utils/gputraits.cuh"
-#elif GMX_GPU_OPENCL
-#    include <clFFT.h>
-
-#    include "gromacs/gpu_utils/gmxopencl.h"
-#    include "gromacs/gpu_utils/gputraits_ocl.h"
-#elif GMX_GPU_SYCL
-#    include "gromacs/gpu_utils/gputraits_sycl.h"
-#endif
-
-#include "gromacs/fft/fft.h" // for the enum gmx_fft_direction
+#include "gromacs/fft/fft.h"
+#include "gromacs/gpu_utils/devicebuffer_datatype.h"
+#include "gromacs/gpu_utils/gputraits.h"
 
+class DeviceContext;
+class DeviceStream;
 struct PmeGpu;
 
 /*! \internal \brief
@@ -73,12 +62,28 @@ class GpuParallel3dFft
 {
 public:
     /*! \brief
-     * Constructs CUDA/OpenCL FFT plans for performing 3D FFT on a PME grid.
+     * Constructs GPU FFT plans for performing 3D FFT on a PME grid.
      *
-     * \param[in] pmeGpu                  The PME GPU structure.
-     * \param[in] gridIndex               The index of the grid on which to perform the calculations.
+     * \param[in]  realGridSize           Dimensions of the real grid
+     * \param[in]  realGridSizePadded     Dimensions of the real grid with padding
+     * \param[in]  complexGridSizePadded  Dimensions of the complex grid with padding
+     * \param[in]  useDecomposition       Whether PME decomposition will be used
+     * \param[in]  performOutOfPlaceFFT   Whether the FFT will be performed out-of-place
+     * \param[in]  context                GPU context.
+     * \param[in]  pmeStream              GPU stream for PME.
+     * \param[in]  realGrid               Device buffer of floats for the real grid
+     * \param[in]  complexGrid            Device buffer of complex floats for the complex grid
      */
-    GpuParallel3dFft(const PmeGpu* pmeGpu, int gridIndex);
+    GpuParallel3dFft(ivec                 realGridSize,
+                     ivec                 realGridSizePadded,
+                     ivec                 complexGridSizePadded,
+                     bool                 useDecomposition,
+                     bool                 performOutOfPlaceFFT,
+                     const DeviceContext& context,
+                     const DeviceStream&  pmeStream,
+                     DeviceBuffer<float>  realGrid,
+                     DeviceBuffer<float>  complexGrid);
+
     /*! \brief Destroys the FFT plans. */
     ~GpuParallel3dFft();
     /*! \brief Performs the FFT transform in given direction
@@ -89,18 +94,8 @@ public:
     void perform3dFft(gmx_fft_direction dir, CommandEvent* timingEvent);
 
 private:
-#if GMX_GPU_CUDA
-    cufftHandle   planR2C_;
-    cufftHandle   planC2R_;
-    cufftReal*    realGrid_;
-    cufftComplex* complexGrid_;
-#elif GMX_GPU_OPENCL
-    clfftPlanHandle               planR2C_;
-    clfftPlanHandle               planC2R_;
-    std::vector<cl_command_queue> deviceStreams_;
-    cl_mem                        realGrid_;
-    cl_mem                        complexGrid_;
-#endif
+    class Impl;
+    std::unique_ptr<Impl> impl_;
 };
 
 #endif
index f9a5f11f060e59d3426381741af4885adfb10567..d63901d3197d2556785d9b1721e08dcd982c441f 100644 (file)
@@ -1,7 +1,7 @@
 /*
  * This file is part of the GROMACS molecular simulation package.
  *
- * Copyright (c) 2016,2017,2018,2019,2020, by the GROMACS development team, led by
+ * Copyright (c) 2016,2017,2018,2019,2020,2021, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
 
 #include "gmxpre.h"
 
+#include "pme_gpu_3dfft.h"
+
 #include <array>
+#include <vector>
 
+#include <clFFT.h>
+
+#include "gromacs/gpu_utils/device_context.h"
+#include "gromacs/gpu_utils/device_stream.h"
+#include "gromacs/gpu_utils/gmxopencl.h"
 #include "gromacs/utility/exceptions.h"
 #include "gromacs/utility/gmxassert.h"
 #include "gromacs/utility/stringutil.h"
 
-#include "pme_gpu_3dfft.h"
-#include "pme_gpu_internal.h"
-#include "pme_gpu_types.h"
-#include "pme_gpu_types_host_impl.h"
+class GpuParallel3dFft::Impl
+{
+public:
+    Impl(ivec                 realGridSize,
+         ivec                 realGridSizePadded,
+         ivec                 complexGridSizePadded,
+         bool                 useDecomposition,
+         bool                 performOutOfPlaceFFT,
+         const DeviceContext& context,
+         const DeviceStream&  pmeStream,
+         DeviceBuffer<float>  realGrid,
+         DeviceBuffer<float>  complexGrid);
+    ~Impl();
+
+    clfftPlanHandle               planR2C_;
+    clfftPlanHandle               planC2R_;
+    std::vector<cl_command_queue> commandStreams_;
+    cl_mem                        realGrid_;
+    cl_mem                        complexGrid_;
+};
 
 //! Throws the exception on clFFT error
 static void handleClfftError(clfftStatus status, const char* msg)
@@ -63,41 +87,35 @@ static void handleClfftError(clfftStatus status, const char* msg)
     }
 }
 
-GpuParallel3dFft::GpuParallel3dFft(const PmeGpu* pmeGpu, const int gridIndex)
+GpuParallel3dFft::Impl::Impl(ivec                 realGridSize,
+                             ivec                 realGridSizePadded,
+                             ivec                 complexGridSizePadded,
+                             const bool           useDecomposition,
+                             const bool           performOutOfPlaceFFT,
+                             const DeviceContext& context,
+                             const DeviceStream&  pmeStream,
+                             DeviceBuffer<float>  realGrid,
+                             DeviceBuffer<float>  complexGrid) :
+    realGrid_(realGrid), complexGrid_(complexGrid)
 {
-    // Extracting all the data from PME GPU
-    std::array<size_t, DIM> realGridSize, realGridSizePadded, complexGridSizePadded;
-
-    GMX_RELEASE_ASSERT(!pme_gpu_settings(pmeGpu).useDecomposition,
-                       "FFT decomposition not implemented");
-    PmeGpuKernelParamsBase* kernelParamsPtr = pmeGpu->kernelParams.get();
-    for (int i = 0; i < DIM; i++)
-    {
-        realGridSize[i]          = kernelParamsPtr->grid.realGridSize[i];
-        realGridSizePadded[i]    = kernelParamsPtr->grid.realGridSizePadded[i];
-        complexGridSizePadded[i] = kernelParamsPtr->grid.complexGridSizePadded[i];
-        GMX_ASSERT(kernelParamsPtr->grid.complexGridSizePadded[i]
-                           == kernelParamsPtr->grid.complexGridSize[i],
-                   "Complex padding not implemented");
-    }
-    cl_context context = pmeGpu->archSpecific->deviceContext_.context();
-    deviceStreams_.push_back(pmeGpu->archSpecific->pmeStream_.stream());
-    realGrid_                       = kernelParamsPtr->grid.d_realGrid[gridIndex];
-    complexGrid_                    = kernelParamsPtr->grid.d_fourierGrid[gridIndex];
-    const bool performOutOfPlaceFFT = pmeGpu->archSpecific->performOutOfPlaceFFT;
+    GMX_RELEASE_ASSERT(!useDecomposition, "FFT decomposition not implemented");
 
+    cl_context clContext = context.context();
+    commandStreams_.push_back(pmeStream.stream());
 
     // clFFT expects row-major, so dimensions/strides are reversed (ZYX instead of XYZ)
-    std::array<size_t, DIM> realGridDimensions = { realGridSize[ZZ], realGridSize[YY], realGridSize[XX] };
-    std::array<size_t, DIM> realGridStrides    = { 1,
-                                                realGridSizePadded[ZZ],
-                                                realGridSizePadded[YY] * realGridSizePadded[ZZ] };
+    std::array<size_t, DIM> realGridDimensions = { size_t(realGridSize[ZZ]),
+                                                   size_t(realGridSize[YY]),
+                                                   size_t(realGridSize[XX]) };
+    std::array<size_t, DIM> realGridStrides    = {
+        1, size_t(realGridSizePadded[ZZ]), size_t(realGridSizePadded[YY] * realGridSizePadded[ZZ])
+    };
     std::array<size_t, DIM> complexGridStrides = {
-        1, complexGridSizePadded[ZZ], complexGridSizePadded[YY] * complexGridSizePadded[ZZ]
+        1, size_t(complexGridSizePadded[ZZ]), size_t(complexGridSizePadded[YY] * complexGridSizePadded[ZZ])
     };
 
     constexpr clfftDim dims = CLFFT_3D;
-    handleClfftError(clfftCreateDefaultPlan(&planR2C_, context, dims, realGridDimensions.data()),
+    handleClfftError(clfftCreateDefaultPlan(&planR2C_, clContext, dims, realGridDimensions.data()),
                      "clFFT planning failure");
     handleClfftError(clfftSetResultLocation(planR2C_, performOutOfPlaceFFT ? CLFFT_OUTOFPLACE : CLFFT_INPLACE),
                      "clFFT planning failure");
@@ -109,7 +127,7 @@ GpuParallel3dFft::GpuParallel3dFft(const PmeGpu* pmeGpu, const int gridIndex)
                      "clFFT coefficient setup failure");
 
     // The only difference between 2 plans is direction
-    handleClfftError(clfftCopyPlan(&planC2R_, context, planR2C_), "clFFT plan copying failure");
+    handleClfftError(clfftCopyPlan(&planC2R_, clContext, planR2C_), "clFFT plan copying failure");
 
     handleClfftError(clfftSetLayout(planR2C_, CLFFT_REAL, CLFFT_HERMITIAN_INTERLEAVED),
                      "clFFT R2C layout failure");
@@ -126,16 +144,16 @@ GpuParallel3dFft::GpuParallel3dFft(const PmeGpu* pmeGpu, const int gridIndex)
     handleClfftError(clfftSetPlanOutStride(planC2R_, dims, realGridStrides.data()),
                      "clFFT stride setting failure");
 
-    handleClfftError(clfftBakePlan(planR2C_, deviceStreams_.size(), deviceStreams_.data(), nullptr, nullptr),
+    handleClfftError(clfftBakePlan(planR2C_, commandStreams_.size(), commandStreams_.data(), nullptr, nullptr),
                      "clFFT precompiling failure");
-    handleClfftError(clfftBakePlan(planC2R_, deviceStreams_.size(), deviceStreams_.data(), nullptr, nullptr),
+    handleClfftError(clfftBakePlan(planC2R_, commandStreams_.size(), commandStreams_.data(), nullptr, nullptr),
                      "clFFT precompiling failure");
 
     // TODO: implement solve kernel as R2C FFT callback
     // TODO: disable last transpose (clfftSetPlanTransposeResult)
 }
 
-GpuParallel3dFft::~GpuParallel3dFft()
+GpuParallel3dFft::Impl::~Impl()
 {
     clfftDestroyPlan(&planR2C_);
     clfftDestroyPlan(&planC2R_);
@@ -153,16 +171,16 @@ void GpuParallel3dFft::perform3dFft(gmx_fft_direction dir, CommandEvent* timingE
     switch (dir)
     {
         case GMX_FFT_REAL_TO_COMPLEX:
-            plan        = planR2C_;
+            plan        = impl_->planR2C_;
             direction   = CLFFT_FORWARD;
-            inputGrids  = &realGrid_;
-            outputGrids = &complexGrid_;
+            inputGrids  = &impl_->realGrid_;
+            outputGrids = &impl_->complexGrid_;
             break;
         case GMX_FFT_COMPLEX_TO_REAL:
-            plan        = planC2R_;
+            plan        = impl_->planC2R_;
             direction   = CLFFT_BACKWARD;
-            inputGrids  = &complexGrid_;
-            outputGrids = &realGrid_;
+            inputGrids  = &impl_->complexGrid_;
+            outputGrids = &impl_->realGrid_;
             break;
         default:
             GMX_THROW(
@@ -170,8 +188,8 @@ void GpuParallel3dFft::perform3dFft(gmx_fft_direction dir, CommandEvent* timingE
     }
     handleClfftError(clfftEnqueueTransform(plan,
                                            direction,
-                                           deviceStreams_.size(),
-                                           deviceStreams_.data(),
+                                           impl_->commandStreams_.size(),
+                                           impl_->commandStreams_.data(),
                                            waitEvents.size(),
                                            waitEvents.data(),
                                            timingEvent,
@@ -180,3 +198,26 @@ void GpuParallel3dFft::perform3dFft(gmx_fft_direction dir, CommandEvent* timingE
                                            tempBuffer),
                      "clFFT execution failure");
 }
+
+GpuParallel3dFft::GpuParallel3dFft(ivec                 realGridSize,
+                                   ivec                 realGridSizePadded,
+                                   ivec                 complexGridSizePadded,
+                                   const bool           useDecomposition,
+                                   const bool           performOutOfPlaceFFT,
+                                   const DeviceContext& context,
+                                   const DeviceStream&  pmeStream,
+                                   DeviceBuffer<float>  realGrid,
+                                   DeviceBuffer<float>  complexGrid) :
+    impl_(std::make_unique<Impl>(realGridSize,
+                                 realGridSizePadded,
+                                 complexGridSizePadded,
+                                 useDecomposition,
+                                 performOutOfPlaceFFT,
+                                 context,
+                                 pmeStream,
+                                 realGrid,
+                                 complexGrid))
+{
+}
+
+GpuParallel3dFft::~GpuParallel3dFft() = default;
diff --git a/src/gromacs/ewald/pme_gpu_3dfft_sycl.cpp b/src/gromacs/ewald/pme_gpu_3dfft_sycl.cpp
new file mode 100644 (file)
index 0000000..213889b
--- /dev/null
@@ -0,0 +1,78 @@
+/*
+ * This file is part of the GROMACS molecular simulation package.
+ *
+ * Copyright (c) 2016,2017,2018,2019,2020,2021, by the GROMACS development team, led by
+ * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
+ * and including many others, as listed in the AUTHORS file in the
+ * top-level source directory and at http://www.gromacs.org.
+ *
+ * GROMACS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1
+ * of the License, or (at your option) any later version.
+ *
+ * GROMACS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with GROMACS; if not, see
+ * http://www.gnu.org/licenses, or write to the Free Software Foundation,
+ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
+ *
+ * If you want to redistribute modifications to GROMACS, please
+ * consider that scientific software is very special. Version
+ * control is crucial - bugs must be traceable. We will be happy to
+ * consider code for inclusion in the official distribution, but
+ * derived work must not be called official GROMACS. Details are found
+ * in the README & COPYING files - if they are missing, get the
+ * official version at http://www.gromacs.org.
+ *
+ * To help us fund GROMACS development, we humbly ask that you cite
+ * the research papers on the package. Check out http://www.gromacs.org.
+ */
+
+/*! \internal \file
+ *  \brief Implements stub GPU 3D FFT routines for SYCL (the constructor and perform3dFft throw NotImplementedError).
+ *
+ *  \author Andrey Alekseenko <al42and@gmail.com>
+ *  \author Mark Abraham <mark.j.abraham@gmail.com>
+ *  \ingroup module_ewald
+ */
+
+#include "gmxpre.h"
+
+#include "pme_gpu_3dfft.h"
+
+#include "gromacs/utility/exceptions.h"
+
+// [[noreturn]] attributes must be added in the common headers, so it's easier to silence the warning here
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wmissing-noreturn"
+
+class GpuParallel3dFft::Impl
+{
+};
+
+GpuParallel3dFft::GpuParallel3dFft(ivec /*realGridSize*/,
+                                   ivec /*realGridSizePadded*/,
+                                   ivec /*complexGridSizePadded*/,
+                                   const bool /*useDecomposition*/,
+                                   const bool /*performOutOfPlaceFFT*/,
+                                   const DeviceContext& /*context*/,
+                                   const DeviceStream& /*pmeStream*/,
+                                   DeviceBuffer<float> /*realGrid*/,
+                                   DeviceBuffer<float> /*complexGrid*/)
+{
+    GMX_THROW(gmx::NotImplementedError("PME is not implemented in SYCL"));
+}
+
+GpuParallel3dFft::~GpuParallel3dFft() = default;
+
+void GpuParallel3dFft::perform3dFft(gmx_fft_direction /*dir*/, CommandEvent* /*timingEvent*/)
+{
+    GMX_THROW(gmx::NotImplementedError("Not implemented on SYCL yet"));
+}
+
+#pragma clang diagnostic pop
index 328e0c11f84eab784f4b21b222a534981a735bc1..3561c2002570df7cdc27d4aea6d61e6d876fc5d4 100644 (file)
@@ -606,9 +606,21 @@ void pme_gpu_reinit_3dfft(const PmeGpu* pmeGpu)
     if (pme_gpu_settings(pmeGpu).performGPUFFT)
     {
         pmeGpu->archSpecific->fftSetup.resize(0);
+        const bool        useDecomposition     = pme_gpu_settings(pmeGpu).useDecomposition;
+        const bool        performOutOfPlaceFFT = pmeGpu->archSpecific->performOutOfPlaceFFT;
+        PmeGpuGridParams& grid                 = pme_gpu_get_kernel_params_base_ptr(pmeGpu)->grid;
         for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
         {
-            pmeGpu->archSpecific->fftSetup.push_back(std::make_unique<GpuParallel3dFft>(pmeGpu, gridIndex));
+            pmeGpu->archSpecific->fftSetup.push_back(
+                    std::make_unique<GpuParallel3dFft>(grid.realGridSize,
+                                                       grid.realGridSizePadded,
+                                                       grid.complexGridSizePadded,
+                                                       useDecomposition,
+                                                       performOutOfPlaceFFT,
+                                                       pmeGpu->archSpecific->deviceContext_,
+                                                       pmeGpu->archSpecific->pmeStream_,
+                                                       grid.d_realGrid[gridIndex],
+                                                       grid.d_fourierGrid[gridIndex]));
         }
     }
 }
index c6f6c1c303132c330acc96d9331386d5a5fd2554..832395e084e2c7eed751da891b6a6521f6c57051 100644 (file)
@@ -44,8 +44,6 @@
 
 #include "gromacs/ewald/ewald_utils.h"
 
-#include "pme_gpu_3dfft.h"
-#include "pme_gpu_internal.h"
 #include "pme_gpu_program_impl.h"
 
 PmeGpuProgramImpl::PmeGpuProgramImpl(const DeviceContext& deviceContext) :
@@ -59,21 +57,3 @@ PmeGpuProgramImpl::PmeGpuProgramImpl(const DeviceContext& deviceContext) :
 }
 
 PmeGpuProgramImpl::~PmeGpuProgramImpl() = default;
-
-// [[noreturn]] attributes must be added in the common headers, so it's easier to silence the warning here
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wmissing-noreturn"
-
-GpuParallel3dFft::GpuParallel3dFft(PmeGpu const* /*pmeGpu*/, int /*gridIndex*/)
-{
-    GMX_THROW(gmx::NotImplementedError("PME is not implemented in SYCL"));
-}
-
-GpuParallel3dFft::~GpuParallel3dFft() = default;
-
-void GpuParallel3dFft::perform3dFft(gmx_fft_direction /*dir*/, CommandEvent* /*timingEvent*/)
-{
-    GMX_THROW(gmx::NotImplementedError("Not implemented on SYCL yet"));
-}
-
-#pragma clang diagnostic pop