Merge remote-tracking branch 'origin/release-2021' into master
diff --git a/src/gromacs/ewald/pme_gpu_internal.cpp b/src/gromacs/ewald/pme_gpu_internal.cpp
index 8b9dea9fab9d5734c2e283c3e28e8736689f119a..185fea7de3420607514eb54726087591e39f63a2 100644
 #include <string>
 
 #include "gromacs/ewald/ewald_utils.h"
+#include "gromacs/fft/gpu_3dfft.h"
 #include "gromacs/gpu_utils/device_context.h"
 #include "gromacs/gpu_utils/device_stream.h"
 #include "gromacs/gpu_utils/gpu_utils.h"
+#include "gromacs/gpu_utils/pmalloc.h"
+#if GMX_GPU_SYCL
+#    include "gromacs/gpu_utils/syclutils.h"
+#endif
 #include "gromacs/hardware/device_information.h"
 #include "gromacs/math/invertmatrix.h"
 #include "gromacs/math/units.h"
 #include "gromacs/utility/gmxassert.h"
 #include "gromacs/utility/logger.h"
 #include "gromacs/utility/stringutil.h"
+#include "gromacs/ewald/pme.h"
 
 #if GMX_GPU_CUDA
-#    include "gromacs/gpu_utils/pmalloc_cuda.h"
-
 #    include "pme.cuh"
-#elif GMX_GPU_OPENCL
-#    include "gromacs/gpu_utils/gmxopencl.h"
 #endif
 
-#include "gromacs/ewald/pme.h"
-
-#include "pme_gpu_3dfft.h"
 #include "pme_gpu_calculate_splines.h"
 #include "pme_gpu_constants.h"
 #include "pme_gpu_program_impl.h"
@@ -143,7 +142,8 @@ void pme_gpu_alloc_energy_virial(PmeGpu* pmeGpu)
     for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
         allocateDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex],
-                             c_virialAndEnergyCount, pmeGpu->archSpecific->deviceContext_);
+                             c_virialAndEnergyCount,
+                             pmeGpu->archSpecific->deviceContext_);
         pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_virialAndEnergy[gridIndex]), energyAndVirialSize);
     }
 }
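
Editor's note on the hunk above: each per-grid device buffer for the energy/virial output is paired with a pmalloc'd (page-locked) host staging buffer, so the copy-back after the solve kernel can run asynchronously on the PME stream. Below is a minimal, self-contained sketch of that pairing; the types and pmallocStub are stand-ins (plain malloc is used in place of the pinned allocator).

```cpp
#include <cstdlib>
#include <vector>

// Stand-in for DeviceBuffer<float>; the real type wraps a device allocation.
struct DeviceBufferStub
{
    float* ptr = nullptr;
};

// pmalloc() returns page-locked host memory in GROMACS; plain malloc keeps
// this sketch self-contained.
static void* pmallocStub(std::size_t bytes)
{
    return std::malloc(bytes);
}

int main()
{
    const int ngrids = 2;                 // two grids when running FEP
    const int c_virialAndEnergyCount = 7; // 6 virial components + 1 energy
    std::vector<DeviceBufferStub> d_virialAndEnergy(ngrids);
    std::vector<float*>           h_virialAndEnergy(ngrids);
    for (int gridIndex = 0; gridIndex < ngrids; gridIndex++)
    {
        d_virialAndEnergy[gridIndex].ptr = new float[c_virialAndEnergyCount]; // device side
        h_virialAndEnergy[gridIndex]     = static_cast<float*>(
                pmallocStub(c_virialAndEnergyCount * sizeof(float))); // pinned host side
    }
    return 0;
}
```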
@@ -162,8 +162,10 @@ void pme_gpu_clear_energy_virial(const PmeGpu* pmeGpu)
 {
     for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        clearDeviceBufferAsync(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex], 0,
-                               c_virialAndEnergyCount, pmeGpu->archSpecific->pmeStream_);
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex],
+                               0,
+                               c_virialAndEnergyCount,
+                               pmeGpu->archSpecific->pmeStream_);
     }
 }
 
@@ -175,7 +177,8 @@ void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu, const int gridIndex
     GMX_ASSERT(gridIndex < pmeGpu->common->ngrids,
                "Invalid combination of gridIndex and number of grids");
 
-    const int splineValuesOffset[DIM] = { 0, pmeGpu->kernelParams->grid.realGridSize[XX],
+    const int splineValuesOffset[DIM] = { 0,
+                                          pmeGpu->kernelParams->grid.realGridSize[XX],
                                           pmeGpu->kernelParams->grid.realGridSize[XX]
                                                   + pmeGpu->kernelParams->grid.realGridSize[YY] };
     memcpy(&pmeGpu->kernelParams->grid.splineValuesOffset, &splineValuesOffset, sizeof(splineValuesOffset));
@@ -185,7 +188,8 @@ void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu, const int gridIndex
                                     + pmeGpu->kernelParams->grid.realGridSize[ZZ];
     const bool shouldRealloc = (newSplineValuesSize > pmeGpu->archSpecific->splineValuesSize[gridIndex]);
     reallocateDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex],
-                           newSplineValuesSize, &pmeGpu->archSpecific->splineValuesSize[gridIndex],
+                           newSplineValuesSize,
+                           &pmeGpu->archSpecific->splineValuesSize[gridIndex],
                            &pmeGpu->archSpecific->splineValuesCapacity[gridIndex],
                            pmeGpu->archSpecific->deviceContext_);
     if (shouldRealloc)
@@ -198,12 +202,17 @@ void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu, const int gridIndex
     for (int i = 0; i < DIM; i++)
     {
         memcpy(pmeGpu->staging.h_splineModuli[gridIndex] + splineValuesOffset[i],
-               pmeGpu->common->bsp_mod[i].data(), pmeGpu->common->bsp_mod[i].size() * sizeof(float));
+               pmeGpu->common->bsp_mod[i].data(),
+               pmeGpu->common->bsp_mod[i].size() * sizeof(float));
     }
     /* TODO: pin original buffer instead! */
     copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex],
-                       pmeGpu->staging.h_splineModuli[gridIndex], 0, newSplineValuesSize,
-                       pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
+                       pmeGpu->staging.h_splineModuli[gridIndex],
+                       0,
+                       newSplineValuesSize,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
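
Editor's note: the splineValuesOffset computation above packs the three per-dimension B-spline moduli arrays back to back into a single buffer at offsets {0, nx, nx+ny}, so one host-to-device transfer covers all of them. A self-contained sketch of the packing, with arbitrary example grid sizes:

```cpp
#include <cstring>
#include <vector>

int main()
{
    const int realGridSize[3]       = { 20, 24, 28 }; // nx, ny, nz (example values)
    const int splineValuesOffset[3] = { 0, realGridSize[0], realGridSize[0] + realGridSize[1] };

    std::vector<float> bsp_mod[3]; // per-dimension moduli, dummy values here
    for (int i = 0; i < 3; i++)
    {
        bsp_mod[i].assign(realGridSize[i], 1.0F);
    }

    // One flat host buffer holding all three arrays contiguously.
    std::vector<float> h_splineModuli(realGridSize[0] + realGridSize[1] + realGridSize[2]);
    for (int i = 0; i < 3; i++)
    {
        std::memcpy(h_splineModuli.data() + splineValuesOffset[i],
                    bsp_mod[i].data(),
                    bsp_mod[i].size() * sizeof(float));
    }
    // The packed buffer is then pushed to the device in one copyToDeviceBuffer call.
    return 0;
}
```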
 
 void pme_gpu_free_bspline_values(const PmeGpu* pmeGpu)
@@ -217,10 +226,12 @@ void pme_gpu_free_bspline_values(const PmeGpu* pmeGpu)
 
 void pme_gpu_realloc_forces(PmeGpu* pmeGpu)
 {
-    const size_t newForcesSize = pmeGpu->nAtomsAlloc * DIM;
+    const size_t newForcesSize = pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newForcesSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces, newForcesSize,
-                           &pmeGpu->archSpecific->forcesSize, &pmeGpu->archSpecific->forcesSizeAlloc,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces,
+                           newForcesSize,
+                           &pmeGpu->archSpecific->forcesSize,
+                           &pmeGpu->archSpecific->forcesSizeAlloc,
                            pmeGpu->archSpecific->deviceContext_);
     pmeGpu->staging.h_forces.reserveWithPadding(pmeGpu->nAtomsAlloc);
     pmeGpu->staging.h_forces.resizeWithPadding(pmeGpu->kernelParams->atoms.nAtoms);
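
Editor's note: the size change above (nAtomsAlloc instead of nAtomsAlloc * DIM) follows from d_forces now being typed per three-component vector rather than per float; the matching copy hunks below drop their reinterpret_casts for the same reason. A sketch of the idea, with stand-in types and a hypothetical typed copy helper:

```cpp
#include <array>
#include <cstddef>
#include <vector>

using RVec = std::array<float, 3>; // stand-in for gmx::RVec

// Hypothetical typed copy helper: the element count is in units of T.
template<typename T>
void copyToDeviceStub(T* d_dest, const T* h_src, std::size_t count)
{
    for (std::size_t i = 0; i < count; i++)
    {
        d_dest[i] = h_src[i];
    }
}

int main()
{
    const std::size_t nAtoms = 128;
    std::vector<RVec> h_forces(nAtoms);
    std::vector<RVec> d_forces(nAtoms); // device buffer faked on the host

    // Before: DIM * nAtoms floats through reinterpret_cast<float*>.
    // After:  nAtoms RVec elements, no cast.
    copyToDeviceStub(d_forces.data(), h_forces.data(), nAtoms);
    return 0;
}
```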
@@ -234,19 +245,25 @@ void pme_gpu_free_forces(const PmeGpu* pmeGpu)
 void pme_gpu_copy_input_forces(PmeGpu* pmeGpu)
 {
     GMX_ASSERT(pmeGpu->kernelParams->atoms.nAtoms > 0, "Bad number of atoms in PME GPU");
-    float* h_forcesFloat = reinterpret_cast<float*>(pmeGpu->staging.h_forces.data());
-    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces, h_forcesFloat, 0,
-                       DIM * pmeGpu->kernelParams->atoms.nAtoms, pmeGpu->archSpecific->pmeStream_,
-                       pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces,
+                       pmeGpu->staging.h_forces.data(),
+                       0,
+                       pmeGpu->kernelParams->atoms.nAtoms,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_copy_output_forces(PmeGpu* pmeGpu)
 {
     GMX_ASSERT(pmeGpu->kernelParams->atoms.nAtoms > 0, "Bad number of atoms in PME GPU");
-    float* h_forcesFloat = reinterpret_cast<float*>(pmeGpu->staging.h_forces.data());
-    copyFromDeviceBuffer(h_forcesFloat, &pmeGpu->kernelParams->atoms.d_forces, 0,
-                         DIM * pmeGpu->kernelParams->atoms.nAtoms, pmeGpu->archSpecific->pmeStream_,
-                         pmeGpu->settings.transferKind, nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_forces.data(),
+                         &pmeGpu->kernelParams->atoms.d_forces,
+                         0,
+                         pmeGpu->kernelParams->atoms.nAtoms,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
 }
 
 void pme_gpu_realloc_and_copy_input_coefficients(const PmeGpu* pmeGpu,
@@ -257,19 +274,26 @@ void pme_gpu_realloc_and_copy_input_coefficients(const PmeGpu* pmeGpu,
     const size_t newCoefficientsSize = pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newCoefficientsSize > 0, "Bad number of atoms in PME GPU");
     reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
-                           newCoefficientsSize, &pmeGpu->archSpecific->coefficientsSize[gridIndex],
+                           newCoefficientsSize,
+                           &pmeGpu->archSpecific->coefficientsSize[gridIndex],
                            &pmeGpu->archSpecific->coefficientsCapacity[gridIndex],
                            pmeGpu->archSpecific->deviceContext_);
     copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
-                       const_cast<float*>(h_coefficients), 0, pmeGpu->kernelParams->atoms.nAtoms,
-                       pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
+                       const_cast<float*>(h_coefficients),
+                       0,
+                       pmeGpu->kernelParams->atoms.nAtoms,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 
     const size_t paddingIndex = pmeGpu->kernelParams->atoms.nAtoms;
     const size_t paddingCount = pmeGpu->nAtomsAlloc - paddingIndex;
     if (paddingCount > 0)
     {
-        clearDeviceBufferAsync(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex], paddingIndex,
-                               paddingCount, pmeGpu->archSpecific->pmeStream_);
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
+                               paddingIndex,
+                               paddingCount,
+                               pmeGpu->archSpecific->pmeStream_);
     }
 }
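
Editor's note: the padding clear above matters for correctness, not just hygiene. Atoms are over-allocated (nAtomsAlloc is rounded up for kernel-friendly sizes), and the tail of the charge buffer must be zero so the padded slots spread no charge. A sketch of the logic, assuming warp-multiple rounding for illustration:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

int main()
{
    const int nAtoms      = 100;
    const int alignment   = 32; // assumed rounding granularity for illustration
    const int nAtomsAlloc = ((nAtoms + alignment - 1) / alignment) * alignment; // -> 128

    // Device buffer faked on the host; pretend it holds stale data.
    std::vector<float> d_coefficients(nAtomsAlloc, -1.0F);

    // Real charges occupy [0, nAtoms); zero out [paddingIndex, nAtomsAlloc).
    const int paddingIndex = nAtoms;
    const int paddingCount = nAtomsAlloc - paddingIndex;
    if (paddingCount > 0)
    {
        std::fill_n(d_coefficients.begin() + paddingIndex, paddingCount, 0.0F);
    }
    assert(d_coefficients.back() == 0.0F);
    return 0;
}
```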
 
@@ -290,10 +314,15 @@ void pme_gpu_realloc_spline_data(PmeGpu* pmeGpu)
     const bool shouldRealloc        = (newSplineDataSize > pmeGpu->archSpecific->splineDataSize);
     int        currentSizeTemp      = pmeGpu->archSpecific->splineDataSize;
     int        currentSizeTempAlloc = pmeGpu->archSpecific->splineDataSizeAlloc;
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_theta, newSplineDataSize, &currentSizeTemp,
-                           &currentSizeTempAlloc, pmeGpu->archSpecific->deviceContext_);
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_dtheta, newSplineDataSize,
-                           &pmeGpu->archSpecific->splineDataSize, &pmeGpu->archSpecific->splineDataSizeAlloc,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_theta,
+                           newSplineDataSize,
+                           &currentSizeTemp,
+                           &currentSizeTempAlloc,
+                           pmeGpu->archSpecific->deviceContext_);
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_dtheta,
+                           newSplineDataSize,
+                           &pmeGpu->archSpecific->splineDataSize,
+                           &pmeGpu->archSpecific->splineDataSizeAlloc,
                            pmeGpu->archSpecific->deviceContext_);
     // the host-side reallocation
     if (shouldRealloc)
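
Editor's note on a subtlety in the hunk above: d_theta and d_dtheta share one logical size, so the first reallocateDeviceBuffer call is handed throwaway copies of the size counters and only the second call updates the real ones; otherwise the second call would see an already-updated size and skip the reallocation. A minimal sketch of the pattern:

```cpp
#include <vector>

// Simplified stand-in for reallocateDeviceBuffer: grows the buffer only when
// the requested size exceeds the recorded capacity.
static void reallocateStub(std::vector<float>* buffer, int newSize, int* currentSize, int* currentAlloc)
{
    if (newSize > *currentAlloc)
    {
        buffer->resize(newSize);
        *currentAlloc = newSize;
    }
    *currentSize = newSize;
}

int main()
{
    std::vector<float> d_theta;
    std::vector<float> d_dtheta;
    int       splineDataSize      = 0;
    int       splineDataSizeAlloc = 0;
    const int newSplineDataSize   = 1024;

    // Throwaway copies keep the first call from spoiling the counters that
    // the second call needs to see unmodified.
    int sizeTemp  = splineDataSize;
    int allocTemp = splineDataSizeAlloc;
    reallocateStub(&d_theta, newSplineDataSize, &sizeTemp, &allocTemp);
    reallocateStub(&d_dtheta, newSplineDataSize, &splineDataSize, &splineDataSizeAlloc);
    return 0;
}
```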
@@ -318,7 +347,8 @@ void pme_gpu_realloc_grid_indices(PmeGpu* pmeGpu)
 {
     const size_t newIndicesSize = DIM * pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newIndicesSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_gridlineIndices, newIndicesSize,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_gridlineIndices,
+                           newIndicesSize,
                            &pmeGpu->archSpecific->gridlineIndicesSize,
                            &pmeGpu->archSpecific->gridlineIndicesSizeAlloc,
                            pmeGpu->archSpecific->deviceContext_);
@@ -348,11 +378,13 @@ void pme_gpu_realloc_grids(PmeGpu* pmeGpu)
         if (pmeGpu->archSpecific->performOutOfPlaceFFT)
         {
             /* 2 separate grids */
-            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex], newComplexGridSize,
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                                   newComplexGridSize,
                                    &pmeGpu->archSpecific->complexGridSize[gridIndex],
                                    &pmeGpu->archSpecific->complexGridCapacity[gridIndex],
                                    pmeGpu->archSpecific->deviceContext_);
-            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex], newRealGridSize,
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex],
+                                   newRealGridSize,
                                    &pmeGpu->archSpecific->realGridSize[gridIndex],
                                    &pmeGpu->archSpecific->realGridCapacity[gridIndex],
                                    pmeGpu->archSpecific->deviceContext_);
@@ -361,7 +393,8 @@ void pme_gpu_realloc_grids(PmeGpu* pmeGpu)
         {
             /* A single buffer so that any grid will fit */
             const int newGridsSize = std::max(newRealGridSize, newComplexGridSize);
-            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex], newGridsSize,
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex],
+                                   newGridsSize,
                                    &pmeGpu->archSpecific->realGridSize[gridIndex],
                                    &pmeGpu->archSpecific->realGridCapacity[gridIndex],
                                    pmeGpu->archSpecific->deviceContext_);
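
Editor's note: the two branches above implement the out-of-place versus in-place FFT storage choice: out of place keeps separate real and complex grids, while in place allocates a single buffer sized to fit whichever representation is larger. A sketch of one plausible sizing arithmetic (the factor of two for complex floats and the nz/2 + 1 minor dimension are standard real-to-complex facts, but the exact expressions here are illustrative, not copied from the file):

```cpp
#include <algorithm>

int main()
{
    const int realGridSizePadded[3]    = { 20, 24, 28 };
    const int complexGridSizePadded[3] = { 20, 24, 28 / 2 + 1 }; // R2C shrinks the minor dim

    const int newRealGridSize = realGridSizePadded[0] * realGridSizePadded[1] * realGridSizePadded[2];
    const int newComplexGridSize = complexGridSizePadded[0] * complexGridSizePadded[1]
                                   * complexGridSizePadded[2] * 2; // 2 floats per complex value

    const bool performOutOfPlaceFFT = false;
    const int  singleBufferSize =
            performOutOfPlaceFFT ? newRealGridSize : std::max(newRealGridSize, newComplexGridSize);
    (void)singleBufferSize;
    return 0;
}
```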
@@ -389,7 +422,8 @@ void pme_gpu_clear_grids(const PmeGpu* pmeGpu)
 {
     for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        clearDeviceBufferAsync(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex], 0,
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                               0,
                                pmeGpu->archSpecific->realGridSize[gridIndex],
                                pmeGpu->archSpecific->pmeStream_);
     }
@@ -412,12 +446,16 @@ void pme_gpu_realloc_and_copy_fract_shifts(PmeGpu* pmeGpu)
     const int newFractShiftsSize = cellCount * (nx + ny + nz);
 
     initParamLookupTable(&kernelParamsPtr->grid.d_fractShiftsTable,
-                         &kernelParamsPtr->fractShiftsTableTexture, pmeGpu->common->fsh.data(),
-                         newFractShiftsSize, pmeGpu->archSpecific->deviceContext_);
+                         &kernelParamsPtr->fractShiftsTableTexture,
+                         pmeGpu->common->fsh.data(),
+                         newFractShiftsSize,
+                         pmeGpu->archSpecific->deviceContext_);
 
     initParamLookupTable(&kernelParamsPtr->grid.d_gridlineIndicesTable,
-                         &kernelParamsPtr->gridlineIndicesTableTexture, pmeGpu->common->nn.data(),
-                         newFractShiftsSize, pmeGpu->archSpecific->deviceContext_);
+                         &kernelParamsPtr->gridlineIndicesTableTexture,
+                         pmeGpu->common->nn.data(),
+                         newFractShiftsSize,
+                         pmeGpu->archSpecific->deviceContext_);
 }
 
 void pme_gpu_free_fract_shifts(const PmeGpu* pmeGpu)
@@ -425,10 +463,10 @@ void pme_gpu_free_fract_shifts(const PmeGpu* pmeGpu)
     auto* kernelParamsPtr = pmeGpu->kernelParams.get();
 #if GMX_GPU_CUDA
     destroyParamLookupTable(&kernelParamsPtr->grid.d_fractShiftsTable,
-                            kernelParamsPtr->fractShiftsTableTexture);
+                            &kernelParamsPtr->fractShiftsTableTexture);
     destroyParamLookupTable(&kernelParamsPtr->grid.d_gridlineIndicesTable,
-                            kernelParamsPtr->gridlineIndicesTableTexture);
-#elif GMX_GPU_OPENCL
+                            &kernelParamsPtr->gridlineIndicesTableTexture);
+#elif GMX_GPU_OPENCL || GMX_GPU_SYCL
     freeDeviceBuffer(&kernelParamsPtr->grid.d_fractShiftsTable);
     freeDeviceBuffer(&kernelParamsPtr->grid.d_gridlineIndicesTable);
 #endif
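
Editor's note: the cleanup split above reflects what each backend allocates. On CUDA the fractional-shift and gridline-index tables may be bound to texture objects, so destroyParamLookupTable (now taking the texture by pointer) must tear both down, while OpenCL and SYCL hold plain device buffers. A sketch of the dispatch shape, with stand-in helpers and the backend macros pinned for illustration:

```cpp
// Build-time backend switches; exactly one would be 1 in a real build.
#define GMX_GPU_CUDA 0
#define GMX_GPU_OPENCL 1
#define GMX_GPU_SYCL 0

struct LookupTable
{
    float* d_data = nullptr;
    // On CUDA the real struct also carries a texture object handle.
};

static void freeDeviceBufferStub(float** d)
{
    delete[] *d;
    *d = nullptr;
}

static void freeLookupTable(LookupTable* table)
{
#if GMX_GPU_CUDA
    // destroyParamLookupTable(&table->d_data, &table->texture); // texture now passed by pointer
#elif GMX_GPU_OPENCL || GMX_GPU_SYCL
    freeDeviceBufferStub(&table->d_data);
#endif
}

int main()
{
    LookupTable table;
    table.d_data = new float[8];
    freeLookupTable(&table);
    return 0;
}
```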
@@ -441,16 +479,24 @@ bool pme_gpu_stream_query(const PmeGpu* pmeGpu)
 
 void pme_gpu_copy_input_gather_grid(const PmeGpu* pmeGpu, const float* h_grid, const int gridIndex)
 {
-    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex], h_grid, 0,
+    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                       h_grid,
+                       0,
                        pmeGpu->archSpecific->realGridSize[gridIndex],
-                       pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_copy_output_spread_grid(const PmeGpu* pmeGpu, float* h_grid, const int gridIndex)
 {
-    copyFromDeviceBuffer(h_grid, &pmeGpu->kernelParams->grid.d_realGrid[gridIndex], 0,
+    copyFromDeviceBuffer(h_grid,
+                         &pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                         0,
                          pmeGpu->archSpecific->realGridSize[gridIndex],
-                         pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
     pmeGpu->archSpecific->syncSpreadGridD2H.markEvent(pmeGpu->archSpecific->pmeStream_);
 }
 
@@ -458,13 +504,27 @@ void pme_gpu_copy_output_spread_atom_data(const PmeGpu* pmeGpu)
 {
     const size_t splinesCount    = DIM * pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order;
     auto*        kernelParamsPtr = pmeGpu->kernelParams.get();
-    copyFromDeviceBuffer(pmeGpu->staging.h_dtheta, &kernelParamsPtr->atoms.d_dtheta, 0, splinesCount,
-                         pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
-    copyFromDeviceBuffer(pmeGpu->staging.h_theta, &kernelParamsPtr->atoms.d_theta, 0, splinesCount,
-                         pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
-    copyFromDeviceBuffer(pmeGpu->staging.h_gridlineIndices, &kernelParamsPtr->atoms.d_gridlineIndices,
-                         0, kernelParamsPtr->atoms.nAtoms * DIM, pmeGpu->archSpecific->pmeStream_,
-                         pmeGpu->settings.transferKind, nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_dtheta,
+                         &kernelParamsPtr->atoms.d_dtheta,
+                         0,
+                         splinesCount,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_theta,
+                         &kernelParamsPtr->atoms.d_theta,
+                         0,
+                         splinesCount,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_gridlineIndices,
+                         &kernelParamsPtr->atoms.d_gridlineIndices,
+                         0,
+                         kernelParamsPtr->atoms.nAtoms * DIM,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
 }
 
 void pme_gpu_copy_input_gather_atom_data(const PmeGpu* pmeGpu)
@@ -473,22 +533,40 @@ void pme_gpu_copy_input_gather_atom_data(const PmeGpu* pmeGpu)
     auto*        kernelParamsPtr = pmeGpu->kernelParams.get();
 
     // TODO: could clear only the padding and not the whole thing, but this is test-exclusive code anyway
-    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_gridlineIndices, 0, pmeGpu->nAtomsAlloc * DIM,
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_gridlineIndices,
+                           0,
+                           pmeGpu->nAtomsAlloc * DIM,
                            pmeGpu->archSpecific->pmeStream_);
-    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_dtheta, 0,
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_dtheta,
+                           0,
                            pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
                            pmeGpu->archSpecific->pmeStream_);
-    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_theta, 0,
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_theta,
+                           0,
                            pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
                            pmeGpu->archSpecific->pmeStream_);
 
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_dtheta, pmeGpu->staging.h_dtheta, 0, splinesCount,
-                       pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_theta, pmeGpu->staging.h_theta, 0, splinesCount,
-                       pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_gridlineIndices, pmeGpu->staging.h_gridlineIndices,
-                       0, kernelParamsPtr->atoms.nAtoms * DIM, pmeGpu->archSpecific->pmeStream_,
-                       pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_dtheta,
+                       pmeGpu->staging.h_dtheta,
+                       0,
+                       splinesCount,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_theta,
+                       pmeGpu->staging.h_theta,
+                       0,
+                       splinesCount,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_gridlineIndices,
+                       pmeGpu->staging.h_gridlineIndices,
+                       0,
+                       kernelParamsPtr->atoms.nAtoms * DIM,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_sync_spread_grid(const PmeGpu* pmeGpu)
@@ -528,9 +606,40 @@ void pme_gpu_reinit_3dfft(const PmeGpu* pmeGpu)
     if (pme_gpu_settings(pmeGpu).performGPUFFT)
     {
         pmeGpu->archSpecific->fftSetup.resize(0);
+        const bool         performOutOfPlaceFFT      = pmeGpu->archSpecific->performOutOfPlaceFFT;
+        const bool         allocateGrid              = false;
+        MPI_Comm           comm                      = MPI_COMM_NULL;
+        std::array<int, 1> gridOffsetsInXForEachRank = { 0 };
+        std::array<int, 1> gridOffsetsInYForEachRank = { 0 };
+#if GMX_GPU_CUDA
+        const gmx::FftBackend backend = gmx::FftBackend::Cufft;
+#elif GMX_GPU_OPENCL
+        const gmx::FftBackend backend = gmx::FftBackend::Ocl;
+#elif GMX_GPU_SYCL
+        const gmx::FftBackend backend = gmx::FftBackend::Sycl;
+#else
+        GMX_RELEASE_ASSERT(false, "Unknown GPU backend");
+        const gmx::FftBackend backend = gmx::FftBackend::Count;
+#endif
+
+        PmeGpuGridParams& grid = pme_gpu_get_kernel_params_base_ptr(pmeGpu)->grid;
         for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
         {
-            pmeGpu->archSpecific->fftSetup.push_back(std::make_unique<GpuParallel3dFft>(pmeGpu, gridIndex));
+            pmeGpu->archSpecific->fftSetup.push_back(
+                    std::make_unique<gmx::Gpu3dFft>(backend,
+                                                    allocateGrid,
+                                                    comm,
+                                                    gridOffsetsInXForEachRank,
+                                                    gridOffsetsInYForEachRank,
+                                                    grid.realGridSize[ZZ],
+                                                    performOutOfPlaceFFT,
+                                                    pmeGpu->archSpecific->deviceContext_,
+                                                    pmeGpu->archSpecific->pmeStream_,
+                                                    grid.realGridSize,
+                                                    grid.realGridSizePadded,
+                                                    grid.complexGridSizePadded,
+                                                    &(grid.d_realGrid[gridIndex]),
+                                                    &(grid.d_fourierGrid[gridIndex])));
         }
     }
 }
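
Editor's note: the hunk above replaces the PME-specific GpuParallel3dFft with the shared gmx::Gpu3dFft, picking the FFT backend at compile time and passing single-rank decomposition arguments (MPI_COMM_NULL and one zero offset per dimension), since the PME GPU FFT currently runs on a single rank. A compile-time sketch of the backend pick; the enum values mirror those visible in this diff, and undefined GMX_GPU_* macros evaluate to 0 in the #if chain:

```cpp
enum class FftBackend
{
    Cufft, // CUDA builds
    Ocl,   // OpenCL builds
    Sycl,  // SYCL builds
    Count
};

constexpr FftBackend pickFftBackend()
{
#if GMX_GPU_CUDA
    return FftBackend::Cufft;
#elif GMX_GPU_OPENCL
    return FftBackend::Ocl;
#elif GMX_GPU_SYCL
    return FftBackend::Sycl;
#else
    return FftBackend::Count; // unreachable once a GPU build is configured
#endif
}

int main()
{
    constexpr FftBackend backend = pickFftBackend();
    return backend == FftBackend::Count ? 1 : 0; // 1 signals "no GPU backend configured"
}
```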
@@ -750,7 +859,7 @@ static void pme_gpu_copy_common_data_from(const gmx_pme_t* pme)
     pmeGpu->common->nn.insert(pmeGpu->common->nn.end(), pme->nnz, pme->nnz + cellCount * pme->nkz);
     pmeGpu->common->runMode       = pme->runMode;
     pmeGpu->common->isRankPmeOnly = !pme->bPPnode;
-    pmeGpu->common->boxScaler     = pme->boxScaler;
+    pmeGpu->common->boxScaler     = pme->boxScaler.get();
 }
 
 /*! \libinternal \brief
@@ -815,7 +924,7 @@ static void pme_gpu_init(gmx_pme_t*           pme,
     GMX_ASSERT(pmeGpu->common->epsilon_r != 0.0F, "PME GPU: bad electrostatic coefficient");
 
     auto* kernelParamsPtr               = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
-    kernelParamsPtr->constants.elFactor = ONE_4PI_EPS0 / pmeGpu->common->epsilon_r;
+    kernelParamsPtr->constants.elFactor = gmx::c_one4PiEps0 / pmeGpu->common->epsilon_r;
 }
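
Editor's note on the constant swap above: gmx::c_one4PiEps0 carries the same physical value as the old ONE_4PI_EPS0 macro, 1/(4*pi*eps0) in GROMACS units (about 138.935 kJ mol^-1 nm e^-2), now as a namespaced constant. The factor is scaled by the relative permittivity, as this self-contained check shows:

```cpp
#include <cstdio>

int main()
{
    const float c_one4PiEps0 = 138.935458F; // kJ mol^-1 nm e^-2, GROMACS units
    const float epsilon_r    = 1.0F;        // relative permittivity (vacuum/default)
    const float elFactor     = c_one4PiEps0 / epsilon_r;
    std::printf("elFactor = %g\n", elFactor);
    return 0;
}
```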
 
 void pme_gpu_get_real_grid_sizes(const PmeGpu* pmeGpu, gmx::IVec* gridSize, gmx::IVec* paddedGridSize)
@@ -935,23 +1044,23 @@ void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* chargesA
  * In CUDA, the result can be a nullptr stub, as per the GpuRegionTimer implementation.
  *
  * \param[in] pmeGpu         The PME GPU data structure.
- * \param[in] PMEStageId     The PME GPU stage gtPME_ index from the enum in src/gromacs/timing/gpu_timing.h
+ * \param[in] pmeStageId     The PME GPU stage index from the PmeStage enum
  */
-static CommandEvent* pme_gpu_fetch_timing_event(const PmeGpu* pmeGpu, size_t PMEStageId)
+static CommandEvent* pme_gpu_fetch_timing_event(const PmeGpu* pmeGpu, PmeStage pmeStageId)
 {
     CommandEvent* timingEvent = nullptr;
     if (pme_gpu_timings_enabled(pmeGpu))
     {
-        GMX_ASSERT(PMEStageId < pmeGpu->archSpecific->timingEvents.size(),
-                   "Wrong PME GPU timing event index");
-        timingEvent = pmeGpu->archSpecific->timingEvents[PMEStageId].fetchNextEvent();
+        GMX_ASSERT(pmeStageId < PmeStage::Count, "Wrong PME GPU timing event index");
+        timingEvent = pmeGpu->archSpecific->timingEvents[pmeStageId].fetchNextEvent();
     }
     return timingEvent;
 }
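
Editor's note: the timing identifiers move from free size_t constants (gtPME_*) to a scoped enum, which the assert above exploits by comparing against PmeStage::Count. A sketch of the shape; the enumerator names match those used in this diff, but the ordering and the enum-indexed container are assumptions standing in for the real definitions:

```cpp
#include <array>
#include <cstddef>

enum class PmeStage : int
{
    Spline,
    Spread,
    SplineAndSpread,
    FftTransformR2C,
    Solve,
    FftTransformC2R,
    Gather,
    Count
};

// Minimal stand-in for an enum-indexed array of timing events.
template<typename T>
using PmeStageArray = std::array<T, static_cast<std::size_t>(PmeStage::Count)>;

int main()
{
    PmeStageArray<int> eventCounts{}; // zero-initialized
    const PmeStage     stage = PmeStage::Solve;
    eventCounts[static_cast<std::size_t>(stage)]++; // scoped enums need explicit casts
    return 0;
}
```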
 
 void pme_gpu_3dfft(const PmeGpu* pmeGpu, gmx_fft_direction dir, const int grid_index)
 {
-    int timerId = (dir == GMX_FFT_REAL_TO_COMPLEX) ? gtPME_FFT_R2C : gtPME_FFT_C2R;
+    PmeStage timerId = (dir == GMX_FFT_REAL_TO_COMPLEX) ? PmeStage::FftTransformR2C
+                                                        : PmeStage::FftTransformC2R;
 
     pme_gpu_start_timing(pmeGpu, timerId);
     pmeGpu->archSpecific->fftSetup[grid_index]->perform3dFft(
@@ -1055,8 +1164,8 @@ static auto selectSplineAndSpreadKernelPtr(const PmeGpu*  pmeGpu,
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSplineKernelPtr(const PmeGpu*  pmeGpu,
-                                  ThreadsPerAtom threadsPerAtom,
+static auto selectSplineKernelPtr(const PmeGpu*   pmeGpu,
+                                  ThreadsPerAtom  threadsPerAtom,
                                   bool gmx_unused writeSplinesToGlobal,
                                   const int       numGrids)
 {
@@ -1233,30 +1342,34 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     config.gridSize[0]  = dimGrid.first;
     config.gridSize[1]  = dimGrid.second;
 
-    int                                timingId;
+    PmeStage                           timingId;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (computeSplines)
     {
         if (spreadCharges)
         {
-            timingId  = gtPME_SPLINEANDSPREAD;
-            kernelPtr = selectSplineAndSpreadKernelPtr(pmeGpu, pmeGpu->settings.threadsPerAtom,
+            timingId  = PmeStage::SplineAndSpread;
+            kernelPtr = selectSplineAndSpreadKernelPtr(pmeGpu,
+                                                       pmeGpu->settings.threadsPerAtom,
                                                        writeGlobal || (!recalculateSplines),
                                                        pmeGpu->common->ngrids);
         }
         else
         {
-            timingId  = gtPME_SPLINE;
-            kernelPtr = selectSplineKernelPtr(pmeGpu, pmeGpu->settings.threadsPerAtom,
+            timingId  = PmeStage::Spline;
+            kernelPtr = selectSplineKernelPtr(pmeGpu,
+                                              pmeGpu->settings.threadsPerAtom,
                                               writeGlobal || (!recalculateSplines),
                                               pmeGpu->common->ngrids);
         }
     }
     else
     {
-        timingId  = gtPME_SPREAD;
-        kernelPtr = selectSpreadKernelPtr(pmeGpu, pmeGpu->settings.threadsPerAtom,
-                                          writeGlobal || (!recalculateSplines), pmeGpu->common->ngrids);
+        timingId  = PmeStage::Spread;
+        kernelPtr = selectSpreadKernelPtr(pmeGpu,
+                                          pmeGpu->settings.threadsPerAtom,
+                                          writeGlobal || (!recalculateSplines),
+                                          pmeGpu->common->ngrids);
     }
 
 
@@ -1265,17 +1378,24 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
 #if c_canEmbedBuffers
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->atoms.d_theta,
-            &kernelParamsPtr->atoms.d_dtheta, &kernelParamsPtr->atoms.d_gridlineIndices,
-            &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A], &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
-            &kernelParamsPtr->grid.d_fractShiftsTable, &kernelParamsPtr->grid.d_gridlineIndicesTable,
-            &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
-            &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B], &kernelParamsPtr->atoms.d_coordinates);
+    const auto kernelArgs =
+            prepareGpuKernelArguments(kernelPtr,
+                                      config,
+                                      kernelParamsPtr,
+                                      &kernelParamsPtr->atoms.d_theta,
+                                      &kernelParamsPtr->atoms.d_dtheta,
+                                      &kernelParamsPtr->atoms.d_gridlineIndices,
+                                      &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A],
+                                      &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+                                      &kernelParamsPtr->grid.d_fractShiftsTable,
+                                      &kernelParamsPtr->grid.d_gridlineIndicesTable,
+                                      &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+                                      &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+                                      &kernelParamsPtr->atoms.d_coordinates);
 #endif
 
-    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent,
-                    "PME spline/spread", kernelArgs);
+    launchGpuKernel(
+            kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME spline/spread", kernelArgs);
     pme_gpu_stop_timing(pmeGpu, timingId);
 
     const auto& settings    = pmeGpu->settings;
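
Editor's note: the #if above selects between two argument-passing schemes. Where the backend can embed device-buffer handles inside the kernel-parameter struct (c_canEmbedBuffers), only the struct is passed; otherwise every buffer must be listed as an explicit kernel argument, which is what the long reformatted list implements. A sketch of the split, with the macro value and launch helpers as stand-ins:

```cpp
#include <cstdio>

#define c_canEmbedBuffers 0 // assumed: typically true for CUDA, false for OpenCL

struct KernelParams
{
    float* d_theta = nullptr;
    float* d_grid  = nullptr;
};

#if c_canEmbedBuffers
static void launchStub(const KernelParams* params)
{
    std::printf("struct-only launch: %p\n", static_cast<const void*>(params));
}
#else
static void launchStub(const KernelParams* params, float* const* d_theta, float* const* d_grid)
{
    std::printf("explicit-buffer launch: %p %p %p\n",
                static_cast<const void*>(params),
                static_cast<const void*>(*d_theta),
                static_cast<const void*>(*d_grid));
}
#endif

int main()
{
    KernelParams params;
#if c_canEmbedBuffers
    launchStub(&params);
#else
    launchStub(&params, &params.d_theta, &params.d_grid);
#endif
    return 0;
}
```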
@@ -1316,9 +1436,13 @@ void pme_gpu_solve(const PmeGpu* pmeGpu,
     float* h_gridFloat = reinterpret_cast<float*>(h_grid);
     if (copyInputAndOutputGrid)
     {
-        copyToDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex], h_gridFloat, 0,
+        copyToDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                           h_gridFloat,
+                           0,
                            pmeGpu->archSpecific->complexGridSize[gridIndex],
-                           pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
+                           pmeGpu->archSpecific->pmeStream_,
+                           pmeGpu->settings.transferKind,
+                           nullptr);
     }
 
     int majorDim = -1, middleDim = -1, minorDim = -1;
@@ -1368,7 +1492,7 @@ void pme_gpu_solve(const PmeGpu* pmeGpu,
                          / gridLinesPerBlock;
     config.gridSize[2] = pmeGpu->kernelParams->grid.complexGridSize[majorDim];
 
-    int                                timingId  = gtPME_SOLVE;
+    PmeStage                           timingId  = PmeStage::Solve;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (gridOrdering == GridOrdering::YZX)
     {
@@ -1402,28 +1526,37 @@ void pme_gpu_solve(const PmeGpu* pmeGpu,
 #if c_canEmbedBuffers
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->grid.d_splineModuli[gridIndex],
-            &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex],
-            &kernelParamsPtr->grid.d_fourierGrid[gridIndex]);
+    const auto kernelArgs =
+            prepareGpuKernelArguments(kernelPtr,
+                                      config,
+                                      kernelParamsPtr,
+                                      &kernelParamsPtr->grid.d_splineModuli[gridIndex],
+                                      &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex],
+                                      &kernelParamsPtr->grid.d_fourierGrid[gridIndex]);
 #endif
-    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME solve",
-                    kernelArgs);
+    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME solve", kernelArgs);
     pme_gpu_stop_timing(pmeGpu, timingId);
 
     if (computeEnergyAndVirial)
     {
         copyFromDeviceBuffer(pmeGpu->staging.h_virialAndEnergy[gridIndex],
-                             &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex], 0,
-                             c_virialAndEnergyCount, pmeGpu->archSpecific->pmeStream_,
-                             pmeGpu->settings.transferKind, nullptr);
+                             &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex],
+                             0,
+                             c_virialAndEnergyCount,
+                             pmeGpu->archSpecific->pmeStream_,
+                             pmeGpu->settings.transferKind,
+                             nullptr);
     }
 
     if (copyInputAndOutputGrid)
     {
-        copyFromDeviceBuffer(h_gridFloat, &kernelParamsPtr->grid.d_fourierGrid[gridIndex], 0,
+        copyFromDeviceBuffer(h_gridFloat,
+                             &kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                             0,
                              pmeGpu->archSpecific->complexGridSize[gridIndex],
-                             pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
+                             pmeGpu->archSpecific->pmeStream_,
+                             pmeGpu->settings.transferKind,
+                             nullptr);
     }
 }
 
@@ -1551,10 +1684,12 @@ void pme_gpu_gather(PmeGpu* pmeGpu, real** h_grids, const float lambda)
 
     // TODO test different cache configs
 
-    int                                timingId = gtPME_GATHER;
+    PmeStage                           timingId = PmeStage::Gather;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr =
-            selectGatherKernelPtr(pmeGpu, pmeGpu->settings.threadsPerAtom,
-                                  readGlobal || (!recalculateSplines), pmeGpu->common->ngrids);
+            selectGatherKernelPtr(pmeGpu,
+                                  pmeGpu->settings.threadsPerAtom,
+                                  readGlobal || (!recalculateSplines),
+                                  pmeGpu->common->ngrids);
     // TODO design kernel selection getters and make PmeGpu a friend of PmeGpuProgramImpl
 
     pme_gpu_start_timing(pmeGpu, timingId);
@@ -1572,15 +1707,20 @@ void pme_gpu_gather(PmeGpu* pmeGpu, real** h_grids, const float lambda)
 #if c_canEmbedBuffers
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
-            &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
-            &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A], &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
-            &kernelParamsPtr->atoms.d_theta, &kernelParamsPtr->atoms.d_dtheta,
-            &kernelParamsPtr->atoms.d_gridlineIndices, &kernelParamsPtr->atoms.d_forces);
+    const auto kernelArgs =
+            prepareGpuKernelArguments(kernelPtr,
+                                      config,
+                                      kernelParamsPtr,
+                                      &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+                                      &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+                                      &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A],
+                                      &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+                                      &kernelParamsPtr->atoms.d_theta,
+                                      &kernelParamsPtr->atoms.d_dtheta,
+                                      &kernelParamsPtr->atoms.d_gridlineIndices,
+                                      &kernelParamsPtr->atoms.d_forces);
 #endif
-    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME gather",
-                    kernelArgs);
+    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME gather", kernelArgs);
     pme_gpu_stop_timing(pmeGpu, timingId);
 
     if (pmeGpu->settings.useGpuForceReduction)
@@ -1593,7 +1733,7 @@ void pme_gpu_gather(PmeGpu* pmeGpu, real** h_grids, const float lambda)
     }
 }
 
-void* pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
+DeviceBuffer<gmx::RVec> pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
 {
     if (pmeGpu && pmeGpu->kernelParams)
     {
@@ -1601,7 +1741,7 @@ void* pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
     }
     else
     {
-        return nullptr;
+        return DeviceBuffer<gmx::RVec>{};
     }
 }
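
Editor's note: the signature change closing this file swaps a raw void* for a typed, possibly-empty DeviceBuffer<gmx::RVec>, so callers get type safety and a natural empty-handle check instead of comparing against nullptr. A sketch with stand-in types:

```cpp
#include <array>

using RVec = std::array<float, 3>; // stand-in for gmx::RVec

// Minimal stand-in for DeviceBuffer<T>: a typed handle with an emptiness test.
template<typename T>
struct DeviceBuffer
{
    T* ptr = nullptr;
    explicit operator bool() const { return ptr != nullptr; }
};

static DeviceBuffer<RVec> getForcesHandle(bool haveKernelParams, RVec* d_forces)
{
    if (haveKernelParams)
    {
        return DeviceBuffer<RVec>{ d_forces };
    }
    return DeviceBuffer<RVec>{}; // empty handle replaces the old nullptr return
}

int main()
{
    RVec d_forcesFake[4] = {};
    const DeviceBuffer<RVec> forces = getForcesHandle(true, d_forcesFake);
    return forces ? 0 : 1;
}
```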