Unify coordinate copy handling across GPU platforms
diff --git a/src/gromacs/ewald/pme_gpu_internal.cpp b/src/gromacs/ewald/pme_gpu_internal.cpp
index df5f8e84ccdfb65b9b877a5dd4a7e37482c43d20..5e5bfd8ebeac7cdd9b6cf22ebff2f2dbf0cfc1ab 100644
--- a/src/gromacs/ewald/pme_gpu_internal.cpp
+++ b/src/gromacs/ewald/pme_gpu_internal.cpp
@@ -1,7 +1,8 @@
 /*
  * This file is part of the GROMACS molecular simulation package.
  *
- * Copyright (c) 2016,2017,2018,2019, by the GROMACS development team, led by
+ * Copyright (c) 2016,2017,2018,2019,2020 by the GROMACS development team.
+ * Copyright (c) 2021, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
 #include <string>
 
 #include "gromacs/ewald/ewald_utils.h"
+#include "gromacs/fft/gpu_3dfft.h"
+#include "gromacs/gpu_utils/device_context.h"
+#include "gromacs/gpu_utils/device_stream.h"
 #include "gromacs/gpu_utils/gpu_utils.h"
+#include "gromacs/gpu_utils/pmalloc.h"
+#if GMX_GPU_SYCL
+#    include "gromacs/gpu_utils/syclutils.h"
+#endif
+#include "gromacs/hardware/device_information.h"
 #include "gromacs/math/invertmatrix.h"
 #include "gromacs/math/units.h"
 #include "gromacs/timing/gpu_timing.h"
 #include "gromacs/utility/gmxassert.h"
 #include "gromacs/utility/logger.h"
 #include "gromacs/utility/stringutil.h"
+#include "gromacs/ewald/pme.h"
+#include "gromacs/ewald/pme_coordinate_receiver_gpu.h"
 
-#if GMX_GPU == GMX_GPU_CUDA
-#    include "gromacs/gpu_utils/pmalloc_cuda.h"
-
+#if GMX_GPU_CUDA
 #    include "pme.cuh"
-#elif GMX_GPU == GMX_GPU_OPENCL
-#    include "gromacs/gpu_utils/gmxopencl.h"
 #endif
 
-#include "gromacs/ewald/pme.h"
-
-#include "pme_gpu_3dfft.h"
+#include "pme_gpu_calculate_splines.h"
 #include "pme_gpu_constants.h"
 #include "pme_gpu_program_impl.h"
 #include "pme_gpu_timings.h"
 #include "pme_gpu_types.h"
 #include "pme_gpu_types_host.h"
 #include "pme_gpu_types_host_impl.h"
-#include "pme_gpu_utils.h"
 #include "pme_grid.h"
 #include "pme_internal.h"
 #include "pme_solve.h"
 
-
 /*! \brief
- *  CUDA only.
- *  Controls if we should use order (i.e. 4) threads per atom for the GPU
- *  or order*order (i.e. 16) threads per atom.
+ * CUDA only.
+ * Atom limit above which it is advantageous to recalculate the splines in the gather
+ * kernel and to use fewer threads per atom in the spline-and-spread kernel.
  */
-constexpr bool c_useOrderThreadsPerAtom = false;
-/*! \brief
- * CUDA only.
- * Controls if we should recalculate the splines in the gather or
- * save the values in the spread and reload in the gather.
- */
-constexpr bool c_recalculateSplines = false;
+constexpr int c_pmeGpuPerformanceAtomLimit = 23000;
 
 /*! \internal \brief
  * Wrapper for getting a pointer to the plain C++ part of the GPU kernel parameters structure.
@@ -115,60 +112,74 @@ static PmeGpuKernelParamsBase* pme_gpu_get_kernel_params_base_ptr(const PmeGpu*
     return kernelParamsPtr;
 }
 
-int pme_gpu_get_atom_data_alignment(const PmeGpu* /*unused*/)
-{
-    // TODO: this can be simplified, as c_pmeAtomDataAlignment is now constant
-    if (c_usePadding)
-    {
-        return c_pmeAtomDataAlignment;
-    }
-    else
-    {
-        return 0;
-    }
-}
+/*! \brief
+ * Atom data block size (in terms of number of atoms).
+ * This is the least common multiple of the numbers of atoms processed by
+ * a single block/workgroup of the spread and gather kernels.
+ * The GPU atom data buffers must be padded, which means that
+ * the number of atoms used for determining the size of the memory
+ * allocation must be divisible by this.
+ */
+constexpr int c_pmeAtomDataBlockSize = 64;
 
-int pme_gpu_get_atoms_per_warp(const PmeGpu* pmeGpu)
+int pme_gpu_get_atom_data_block_size()
 {
-    if (c_useOrderThreadsPerAtom)
-    {
-        return pmeGpu->programHandle_->impl_->warpSize / c_pmeSpreadGatherThreadsPerAtom4ThPerAtom;
-    }
-    else
-    {
-        return pmeGpu->programHandle_->impl_->warpSize / c_pmeSpreadGatherThreadsPerAtom;
-    }
+    return c_pmeAtomDataBlockSize;
 }
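// A minimal usage sketch (the helper name below is hypothetical, shown for illustration
// only): callers are expected to round the atom count up to a multiple of this block size
// when sizing the padded GPU atom buffers, e.g.
//
//     static int pme_gpu_padded_atom_count(const int nAtoms)
//     {
//         const int blockSize = pme_gpu_get_atom_data_block_size();
//         return ((nAtoms + blockSize - 1) / blockSize) * blockSize;
//     }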
 
 void pme_gpu_synchronize(const PmeGpu* pmeGpu)
 {
-    gpuStreamSynchronize(pmeGpu->archSpecific->pmeStream);
+    pmeGpu->archSpecific->pmeStream_.synchronize();
 }
 
 void pme_gpu_alloc_energy_virial(PmeGpu* pmeGpu)
 {
     const size_t energyAndVirialSize = c_virialAndEnergyCount * sizeof(float);
-    allocateDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy, c_virialAndEnergyCount,
-                         pmeGpu->archSpecific->context);
-    pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_virialAndEnergy), energyAndVirialSize);
+
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        allocateDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex],
+                             c_virialAndEnergyCount,
+                             pmeGpu->archSpecific->deviceContext_);
+        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_virialAndEnergy[gridIndex]), energyAndVirialSize);
+    }
 }
 
 void pme_gpu_free_energy_virial(PmeGpu* pmeGpu)
 {
-    freeDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy);
-    pfree(pmeGpu->staging.h_virialAndEnergy);
-    pmeGpu->staging.h_virialAndEnergy = nullptr;
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        freeDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex]);
+        pfree(pmeGpu->staging.h_virialAndEnergy[gridIndex]);
+        pmeGpu->staging.h_virialAndEnergy[gridIndex] = nullptr;
+    }
 }
 
 void pme_gpu_clear_energy_virial(const PmeGpu* pmeGpu)
 {
-    clearDeviceBufferAsync(&pmeGpu->kernelParams->constants.d_virialAndEnergy, 0,
-                           c_virialAndEnergyCount, pmeGpu->archSpecific->pmeStream);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex],
+                               0,
+                               c_virialAndEnergyCount,
+                               pmeGpu->archSpecific->pmeStream_);
+    }
 }
 
-void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu)
+void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu, const int gridIndex)
 {
-    const int splineValuesOffset[DIM] = { 0, pmeGpu->kernelParams->grid.realGridSize[XX],
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+    GMX_ASSERT(gridIndex < pmeGpu->common->ngrids,
+               "Invalid combination of gridIndex and number of grids");
+
+    const int splineValuesOffset[DIM] = { 0,
+                                          pmeGpu->kernelParams->grid.realGridSize[XX],
                                           pmeGpu->kernelParams->grid.realGridSize[XX]
                                                   + pmeGpu->kernelParams->grid.realGridSize[YY] };
     memcpy(&pmeGpu->kernelParams->grid.splineValuesOffset, &splineValuesOffset, sizeof(splineValuesOffset));
@@ -176,41 +187,53 @@ void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu)
     const int newSplineValuesSize = pmeGpu->kernelParams->grid.realGridSize[XX]
                                     + pmeGpu->kernelParams->grid.realGridSize[YY]
                                     + pmeGpu->kernelParams->grid.realGridSize[ZZ];
-    const bool shouldRealloc = (newSplineValuesSize > pmeGpu->archSpecific->splineValuesSize);
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli, newSplineValuesSize,
-                           &pmeGpu->archSpecific->splineValuesSize,
-                           &pmeGpu->archSpecific->splineValuesSizeAlloc, pmeGpu->archSpecific->context);
+    const bool shouldRealloc = (newSplineValuesSize > pmeGpu->archSpecific->splineValuesSize[gridIndex]);
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex],
+                           newSplineValuesSize,
+                           &pmeGpu->archSpecific->splineValuesSize[gridIndex],
+                           &pmeGpu->archSpecific->splineValuesCapacity[gridIndex],
+                           pmeGpu->archSpecific->deviceContext_);
     if (shouldRealloc)
     {
         /* Reallocate the host buffer */
-        pfree(pmeGpu->staging.h_splineModuli);
-        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_splineModuli),
+        pfree(pmeGpu->staging.h_splineModuli[gridIndex]);
+        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_splineModuli[gridIndex]),
                 newSplineValuesSize * sizeof(float));
     }
     for (int i = 0; i < DIM; i++)
     {
-        memcpy(pmeGpu->staging.h_splineModuli + splineValuesOffset[i],
-               pmeGpu->common->bsp_mod[i].data(), pmeGpu->common->bsp_mod[i].size() * sizeof(float));
+        memcpy(pmeGpu->staging.h_splineModuli[gridIndex] + splineValuesOffset[i],
+               pmeGpu->common->bsp_mod[i].data(),
+               pmeGpu->common->bsp_mod[i].size() * sizeof(float));
     }
     /* TODO: pin original buffer instead! */
-    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli, pmeGpu->staging.h_splineModuli,
-                       0, newSplineValuesSize, pmeGpu->archSpecific->pmeStream,
-                       pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex],
+                       pmeGpu->staging.h_splineModuli[gridIndex],
+                       0,
+                       newSplineValuesSize,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_free_bspline_values(const PmeGpu* pmeGpu)
 {
-    pfree(pmeGpu->staging.h_splineModuli);
-    freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        pfree(pmeGpu->staging.h_splineModuli[gridIndex]);
+        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex]);
+    }
 }
 
 void pme_gpu_realloc_forces(PmeGpu* pmeGpu)
 {
-    const size_t newForcesSize = pmeGpu->nAtomsAlloc * DIM;
+    const size_t newForcesSize = pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newForcesSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces, newForcesSize,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces,
+                           newForcesSize,
                            &pmeGpu->archSpecific->forcesSize,
-                           &pmeGpu->archSpecific->forcesSizeAlloc, pmeGpu->archSpecific->context);
+                           &pmeGpu->archSpecific->forcesSizeAlloc,
+                           pmeGpu->archSpecific->deviceContext_);
     pmeGpu->staging.h_forces.reserveWithPadding(pmeGpu->nAtomsAlloc);
     pmeGpu->staging.h_forces.resizeWithPadding(pmeGpu->kernelParams->atoms.nAtoms);
 }
@@ -223,89 +246,85 @@ void pme_gpu_free_forces(const PmeGpu* pmeGpu)
 void pme_gpu_copy_input_forces(PmeGpu* pmeGpu)
 {
     GMX_ASSERT(pmeGpu->kernelParams->atoms.nAtoms > 0, "Bad number of atoms in PME GPU");
-    float* h_forcesFloat = reinterpret_cast<float*>(pmeGpu->staging.h_forces.data());
-    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces, h_forcesFloat, 0,
-                       DIM * pmeGpu->kernelParams->atoms.nAtoms, pmeGpu->archSpecific->pmeStream,
-                       pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces,
+                       pmeGpu->staging.h_forces.data(),
+                       0,
+                       pmeGpu->kernelParams->atoms.nAtoms,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_copy_output_forces(PmeGpu* pmeGpu)
 {
     GMX_ASSERT(pmeGpu->kernelParams->atoms.nAtoms > 0, "Bad number of atoms in PME GPU");
-    float* h_forcesFloat = reinterpret_cast<float*>(pmeGpu->staging.h_forces.data());
-    copyFromDeviceBuffer(h_forcesFloat, &pmeGpu->kernelParams->atoms.d_forces, 0,
-                         DIM * pmeGpu->kernelParams->atoms.nAtoms, pmeGpu->archSpecific->pmeStream,
-                         pmeGpu->settings.transferKind, nullptr);
-}
-
-void pme_gpu_realloc_coordinates(const PmeGpu* pmeGpu)
-{
-    const size_t newCoordinatesSize = pmeGpu->nAtomsAlloc * DIM;
-    GMX_ASSERT(newCoordinatesSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coordinates, newCoordinatesSize,
-                           &pmeGpu->archSpecific->coordinatesSize,
-                           &pmeGpu->archSpecific->coordinatesSizeAlloc, pmeGpu->archSpecific->context);
-    if (c_usePadding)
-    {
-        const size_t paddingIndex = DIM * pmeGpu->kernelParams->atoms.nAtoms;
-        const size_t paddingCount = DIM * pmeGpu->nAtomsAlloc - paddingIndex;
-        if (paddingCount > 0)
-        {
-            clearDeviceBufferAsync(&pmeGpu->kernelParams->atoms.d_coordinates, paddingIndex,
-                                   paddingCount, pmeGpu->archSpecific->pmeStream);
-        }
-    }
+    copyFromDeviceBuffer(pmeGpu->staging.h_forces.data(),
+                         &pmeGpu->kernelParams->atoms.d_forces,
+                         0,
+                         pmeGpu->kernelParams->atoms.nAtoms,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
 }
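// Note on the force buffer: with this change the allocation size and the H2D/D2H copy
// counts above drop the DIM factor, i.e. d_forces is addressed in whole-atom (float3)
// elements rather than individual floats. A minimal sketch of that assumption:
//
//     DeviceBuffer<gmx::float3> d_forces;   // one element per atom
//     reallocateDeviceBuffer(&d_forces, pmeGpu->nAtomsAlloc, &size, &capacity, deviceContext);
//     copyToDeviceBuffer(&d_forces, h_forces.data(), 0, nAtoms, stream, transferKind, nullptr);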
 
-void pme_gpu_free_coordinates(const PmeGpu* pmeGpu)
-{
-    freeDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coordinates);
-}
-
-void pme_gpu_realloc_and_copy_input_coefficients(const PmeGpu* pmeGpu, const float* h_coefficients)
+void pme_gpu_realloc_and_copy_input_coefficients(const PmeGpu* pmeGpu,
+                                                 const float*  h_coefficients,
+                                                 const int     gridIndex)
 {
     GMX_ASSERT(h_coefficients, "Bad host-side charge buffer in PME GPU");
     const size_t newCoefficientsSize = pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newCoefficientsSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients, newCoefficientsSize,
-                           &pmeGpu->archSpecific->coefficientsSize,
-                           &pmeGpu->archSpecific->coefficientsSizeAlloc, pmeGpu->archSpecific->context);
-    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients,
-                       const_cast<float*>(h_coefficients), 0, pmeGpu->kernelParams->atoms.nAtoms,
-                       pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
-    if (c_usePadding)
-    {
-        const size_t paddingIndex = pmeGpu->kernelParams->atoms.nAtoms;
-        const size_t paddingCount = pmeGpu->nAtomsAlloc - paddingIndex;
-        if (paddingCount > 0)
-        {
-            clearDeviceBufferAsync(&pmeGpu->kernelParams->atoms.d_coefficients, paddingIndex,
-                                   paddingCount, pmeGpu->archSpecific->pmeStream);
-        }
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
+                           newCoefficientsSize,
+                           &pmeGpu->archSpecific->coefficientsSize[gridIndex],
+                           &pmeGpu->archSpecific->coefficientsCapacity[gridIndex],
+                           pmeGpu->archSpecific->deviceContext_);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
+                       const_cast<float*>(h_coefficients),
+                       0,
+                       pmeGpu->kernelParams->atoms.nAtoms,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
+
+    const size_t paddingIndex = pmeGpu->kernelParams->atoms.nAtoms;
+    const size_t paddingCount = pmeGpu->nAtomsAlloc - paddingIndex;
+    if (paddingCount > 0)
+    {
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
+                               paddingIndex,
+                               paddingCount,
+                               pmeGpu->archSpecific->pmeStream_);
     }
 }
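// The tail of the padded charge buffer is cleared above so that the padding atoms carry
// zero charge and therefore contribute nothing when the spread and gather kernels iterate
// over the full padded atom range.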
 
 void pme_gpu_free_coefficients(const PmeGpu* pmeGpu)
 {
-    freeDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        freeDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex]);
+    }
 }
 
 void pme_gpu_realloc_spline_data(PmeGpu* pmeGpu)
 {
-    const int    order        = pmeGpu->common->pme_order;
-    const int    alignment    = pme_gpu_get_atoms_per_warp(pmeGpu);
-    const size_t nAtomsPadded = ((pmeGpu->nAtomsAlloc + alignment - 1) / alignment) * alignment;
-    const int    newSplineDataSize = DIM * order * nAtomsPadded;
+    const int order             = pmeGpu->common->pme_order;
+    const int newSplineDataSize = DIM * order * pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newSplineDataSize > 0, "Bad number of atoms in PME GPU");
     /* Two arrays of the same size */
     const bool shouldRealloc        = (newSplineDataSize > pmeGpu->archSpecific->splineDataSize);
     int        currentSizeTemp      = pmeGpu->archSpecific->splineDataSize;
     int        currentSizeTempAlloc = pmeGpu->archSpecific->splineDataSizeAlloc;
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_theta, newSplineDataSize,
-                           &currentSizeTemp, &currentSizeTempAlloc, pmeGpu->archSpecific->context);
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_dtheta, newSplineDataSize,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_theta,
+                           newSplineDataSize,
+                           &currentSizeTemp,
+                           &currentSizeTempAlloc,
+                           pmeGpu->archSpecific->deviceContext_);
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_dtheta,
+                           newSplineDataSize,
                            &pmeGpu->archSpecific->splineDataSize,
-                           &pmeGpu->archSpecific->splineDataSizeAlloc, pmeGpu->archSpecific->context);
+                           &pmeGpu->archSpecific->splineDataSizeAlloc,
+                           pmeGpu->archSpecific->deviceContext_);
     // the host side reallocation
     if (shouldRealloc)
     {
@@ -329,9 +348,11 @@ void pme_gpu_realloc_grid_indices(PmeGpu* pmeGpu)
 {
     const size_t newIndicesSize = DIM * pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newIndicesSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_gridlineIndices, newIndicesSize,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_gridlineIndices,
+                           newIndicesSize,
                            &pmeGpu->archSpecific->gridlineIndicesSize,
-                           &pmeGpu->archSpecific->gridlineIndicesSizeAlloc, pmeGpu->archSpecific->context);
+                           &pmeGpu->archSpecific->gridlineIndicesSizeAlloc,
+                           pmeGpu->archSpecific->deviceContext_);
     pfree(pmeGpu->staging.h_gridlineIndices);
     pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_gridlineIndices), newIndicesSize * sizeof(int));
 }
@@ -344,50 +365,69 @@ void pme_gpu_free_grid_indices(const PmeGpu* pmeGpu)
 
 void pme_gpu_realloc_grids(PmeGpu* pmeGpu)
 {
-    auto*     kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+
     const int newRealGridSize = kernelParamsPtr->grid.realGridSizePadded[XX]
                                 * kernelParamsPtr->grid.realGridSizePadded[YY]
                                 * kernelParamsPtr->grid.realGridSizePadded[ZZ];
     const int newComplexGridSize = kernelParamsPtr->grid.complexGridSizePadded[XX]
                                    * kernelParamsPtr->grid.complexGridSizePadded[YY]
                                    * kernelParamsPtr->grid.complexGridSizePadded[ZZ] * 2;
-    // Multiplied by 2 because we count complex grid size for complex numbers, but all allocations/pointers are float
-    if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        /* 2 separate grids */
-        reallocateDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid, newComplexGridSize,
-                               &pmeGpu->archSpecific->complexGridSize,
-                               &pmeGpu->archSpecific->complexGridSizeAlloc, pmeGpu->archSpecific->context);
-        reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid, newRealGridSize,
-                               &pmeGpu->archSpecific->realGridSize,
-                               &pmeGpu->archSpecific->realGridSizeAlloc, pmeGpu->archSpecific->context);
-    }
-    else
-    {
-        /* A single buffer so that any grid will fit */
-        const int newGridsSize = std::max(newRealGridSize, newComplexGridSize);
-        reallocateDeviceBuffer(
-                &kernelParamsPtr->grid.d_realGrid, newGridsSize, &pmeGpu->archSpecific->realGridSize,
-                &pmeGpu->archSpecific->realGridSizeAlloc, pmeGpu->archSpecific->context);
-        kernelParamsPtr->grid.d_fourierGrid   = kernelParamsPtr->grid.d_realGrid;
-        pmeGpu->archSpecific->complexGridSize = pmeGpu->archSpecific->realGridSize;
-        // the size might get used later for copying the grid
+        // Multiplied by 2 because we count complex grid size for complex numbers, but all allocations/pointers are float
+        if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+        {
+            /* 2 separate grids */
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                                   newComplexGridSize,
+                                   &pmeGpu->archSpecific->complexGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->complexGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex],
+                                   newRealGridSize,
+                                   &pmeGpu->archSpecific->realGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->realGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+        }
+        else
+        {
+            /* A single buffer so that any grid will fit */
+            const int newGridsSize = std::max(newRealGridSize, newComplexGridSize);
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex],
+                                   newGridsSize,
+                                   &pmeGpu->archSpecific->realGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->realGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+            kernelParamsPtr->grid.d_fourierGrid[gridIndex] = kernelParamsPtr->grid.d_realGrid[gridIndex];
+            pmeGpu->archSpecific->complexGridSize[gridIndex] =
+                    pmeGpu->archSpecific->realGridSize[gridIndex];
+            // the size might get used later for copying the grid
+        }
     }
 }
 
 void pme_gpu_free_grids(const PmeGpu* pmeGpu)
 {
-    if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_fourierGrid);
+        if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+        {
+            freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_fourierGrid[gridIndex]);
+        }
+        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex]);
     }
-    freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid);
 }
 
 void pme_gpu_clear_grids(const PmeGpu* pmeGpu)
 {
-    clearDeviceBufferAsync(&pmeGpu->kernelParams->grid.d_realGrid, 0,
-                           pmeGpu->archSpecific->realGridSize, pmeGpu->archSpecific->pmeStream);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                               0,
+                               pmeGpu->archSpecific->realGridSize[gridIndex],
+                               pmeGpu->archSpecific->pmeStream_);
+    }
 }
 
 void pme_gpu_realloc_and_copy_fract_shifts(PmeGpu* pmeGpu)
@@ -406,37 +446,28 @@ void pme_gpu_realloc_and_copy_fract_shifts(PmeGpu* pmeGpu)
 
     const int newFractShiftsSize = cellCount * (nx + ny + nz);
 
-#if GMX_GPU == GMX_GPU_CUDA
-    initParamLookupTable(kernelParamsPtr->grid.d_fractShiftsTable, kernelParamsPtr->fractShiftsTableTexture,
-                         pmeGpu->common->fsh.data(), newFractShiftsSize);
-
-    initParamLookupTable(kernelParamsPtr->grid.d_gridlineIndicesTable,
-                         kernelParamsPtr->gridlineIndicesTableTexture, pmeGpu->common->nn.data(),
-                         newFractShiftsSize);
-#elif GMX_GPU == GMX_GPU_OPENCL
-    // No dedicated texture routines....
-    allocateDeviceBuffer(&kernelParamsPtr->grid.d_fractShiftsTable, newFractShiftsSize,
-                         pmeGpu->archSpecific->context);
-    allocateDeviceBuffer(&kernelParamsPtr->grid.d_gridlineIndicesTable, newFractShiftsSize,
-                         pmeGpu->archSpecific->context);
-    copyToDeviceBuffer(&kernelParamsPtr->grid.d_fractShiftsTable, pmeGpu->common->fsh.data(), 0,
-                       newFractShiftsSize, pmeGpu->archSpecific->pmeStream,
-                       GpuApiCallBehavior::Async, nullptr);
-    copyToDeviceBuffer(&kernelParamsPtr->grid.d_gridlineIndicesTable, pmeGpu->common->nn.data(), 0,
-                       newFractShiftsSize, pmeGpu->archSpecific->pmeStream,
-                       GpuApiCallBehavior::Async, nullptr);
-#endif
+    initParamLookupTable(&kernelParamsPtr->grid.d_fractShiftsTable,
+                         &kernelParamsPtr->fractShiftsTableTexture,
+                         pmeGpu->common->fsh.data(),
+                         newFractShiftsSize,
+                         pmeGpu->archSpecific->deviceContext_);
+
+    initParamLookupTable(&kernelParamsPtr->grid.d_gridlineIndicesTable,
+                         &kernelParamsPtr->gridlineIndicesTableTexture,
+                         pmeGpu->common->nn.data(),
+                         newFractShiftsSize,
+                         pmeGpu->archSpecific->deviceContext_);
 }
 
 void pme_gpu_free_fract_shifts(const PmeGpu* pmeGpu)
 {
     auto* kernelParamsPtr = pmeGpu->kernelParams.get();
-#if GMX_GPU == GMX_GPU_CUDA
-    destroyParamLookupTable(kernelParamsPtr->grid.d_fractShiftsTable,
-                            kernelParamsPtr->fractShiftsTableTexture);
-    destroyParamLookupTable(kernelParamsPtr->grid.d_gridlineIndicesTable,
-                            kernelParamsPtr->gridlineIndicesTableTexture);
-#elif GMX_GPU == GMX_GPU_OPENCL
+#if GMX_GPU_CUDA
+    destroyParamLookupTable(&kernelParamsPtr->grid.d_fractShiftsTable,
+                            &kernelParamsPtr->fractShiftsTableTexture);
+    destroyParamLookupTable(&kernelParamsPtr->grid.d_gridlineIndicesTable,
+                            &kernelParamsPtr->gridlineIndicesTableTexture);
+#elif GMX_GPU_OPENCL || GMX_GPU_SYCL
     freeDeviceBuffer(&kernelParamsPtr->grid.d_fractShiftsTable);
     freeDeviceBuffer(&kernelParamsPtr->grid.d_gridlineIndicesTable);
 #endif
@@ -444,63 +475,99 @@ void pme_gpu_free_fract_shifts(const PmeGpu* pmeGpu)
 
 bool pme_gpu_stream_query(const PmeGpu* pmeGpu)
 {
-    return haveStreamTasksCompleted(pmeGpu->archSpecific->pmeStream);
+    return haveStreamTasksCompleted(pmeGpu->archSpecific->pmeStream_);
 }
 
-void pme_gpu_copy_input_gather_grid(const PmeGpu* pmeGpu, float* h_grid)
+void pme_gpu_copy_input_gather_grid(const PmeGpu* pmeGpu, const float* h_grid, const int gridIndex)
 {
-    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid, h_grid, 0, pmeGpu->archSpecific->realGridSize,
-                       pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                       h_grid,
+                       0,
+                       pmeGpu->archSpecific->realGridSize[gridIndex],
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
-void pme_gpu_copy_output_spread_grid(const PmeGpu* pmeGpu, float* h_grid)
+void pme_gpu_copy_output_spread_grid(const PmeGpu* pmeGpu, float* h_grid, const int gridIndex)
 {
-    copyFromDeviceBuffer(h_grid, &pmeGpu->kernelParams->grid.d_realGrid, 0,
-                         pmeGpu->archSpecific->realGridSize, pmeGpu->archSpecific->pmeStream,
-                         pmeGpu->settings.transferKind, nullptr);
-    pmeGpu->archSpecific->syncSpreadGridD2H.markEvent(pmeGpu->archSpecific->pmeStream);
+    copyFromDeviceBuffer(h_grid,
+                         &pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                         0,
+                         pmeGpu->archSpecific->realGridSize[gridIndex],
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
+    pmeGpu->archSpecific->syncSpreadGridD2H.markEvent(pmeGpu->archSpecific->pmeStream_);
 }
 
 void pme_gpu_copy_output_spread_atom_data(const PmeGpu* pmeGpu)
 {
-    const int    alignment       = pme_gpu_get_atoms_per_warp(pmeGpu);
-    const size_t nAtomsPadded    = ((pmeGpu->nAtomsAlloc + alignment - 1) / alignment) * alignment;
-    const size_t splinesCount    = DIM * nAtomsPadded * pmeGpu->common->pme_order;
+    const size_t splinesCount    = DIM * pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order;
     auto*        kernelParamsPtr = pmeGpu->kernelParams.get();
-    copyFromDeviceBuffer(pmeGpu->staging.h_dtheta, &kernelParamsPtr->atoms.d_dtheta, 0, splinesCount,
-                         pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
-    copyFromDeviceBuffer(pmeGpu->staging.h_theta, &kernelParamsPtr->atoms.d_theta, 0, splinesCount,
-                         pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
-    copyFromDeviceBuffer(pmeGpu->staging.h_gridlineIndices, &kernelParamsPtr->atoms.d_gridlineIndices,
-                         0, kernelParamsPtr->atoms.nAtoms * DIM, pmeGpu->archSpecific->pmeStream,
-                         pmeGpu->settings.transferKind, nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_dtheta,
+                         &kernelParamsPtr->atoms.d_dtheta,
+                         0,
+                         splinesCount,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_theta,
+                         &kernelParamsPtr->atoms.d_theta,
+                         0,
+                         splinesCount,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_gridlineIndices,
+                         &kernelParamsPtr->atoms.d_gridlineIndices,
+                         0,
+                         kernelParamsPtr->atoms.nAtoms * DIM,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
 }
 
 void pme_gpu_copy_input_gather_atom_data(const PmeGpu* pmeGpu)
 {
-    const int    alignment       = pme_gpu_get_atoms_per_warp(pmeGpu);
-    const size_t nAtomsPadded    = ((pmeGpu->nAtomsAlloc + alignment - 1) / alignment) * alignment;
-    const size_t splinesCount    = DIM * nAtomsPadded * pmeGpu->common->pme_order;
+    const size_t splinesCount    = DIM * pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order;
     auto*        kernelParamsPtr = pmeGpu->kernelParams.get();
-    if (c_usePadding)
-    {
-        // TODO: could clear only the padding and not the whole thing, but this is a test-exclusive code anyway
-        clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_gridlineIndices, 0,
-                               pmeGpu->nAtomsAlloc * DIM, pmeGpu->archSpecific->pmeStream);
-        clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_dtheta, 0,
-                               pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
-                               pmeGpu->archSpecific->pmeStream);
-        clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_theta, 0,
-                               pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
-                               pmeGpu->archSpecific->pmeStream);
-    }
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_dtheta, pmeGpu->staging.h_dtheta, 0, splinesCount,
-                       pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_theta, pmeGpu->staging.h_theta, 0, splinesCount,
-                       pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_gridlineIndices, pmeGpu->staging.h_gridlineIndices,
-                       0, kernelParamsPtr->atoms.nAtoms * DIM, pmeGpu->archSpecific->pmeStream,
-                       pmeGpu->settings.transferKind, nullptr);
+
+    // TODO: could clear only the padding and not the whole thing, but this is a test-exclusive code anyway
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_gridlineIndices,
+                           0,
+                           pmeGpu->nAtomsAlloc * DIM,
+                           pmeGpu->archSpecific->pmeStream_);
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_dtheta,
+                           0,
+                           pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
+                           pmeGpu->archSpecific->pmeStream_);
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_theta,
+                           0,
+                           pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
+                           pmeGpu->archSpecific->pmeStream_);
+
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_dtheta,
+                       pmeGpu->staging.h_dtheta,
+                       0,
+                       splinesCount,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_theta,
+                       pmeGpu->staging.h_theta,
+                       0,
+                       splinesCount,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_gridlineIndices,
+                       pmeGpu->staging.h_gridlineIndices,
+                       0,
+                       kernelParamsPtr->atoms.nAtoms * DIM,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_sync_spread_grid(const PmeGpu* pmeGpu)
@@ -508,16 +575,16 @@ void pme_gpu_sync_spread_grid(const PmeGpu* pmeGpu)
     pmeGpu->archSpecific->syncSpreadGridD2H.waitForEvent();
 }
 
-void pme_gpu_init_internal(PmeGpu* pmeGpu)
+/*! \brief Internal GPU initialization for PME.
+ *
+ * \param[in]  pmeGpu         GPU PME data.
+ * \param[in]  deviceContext  GPU context.
+ * \param[in]  deviceStream   GPU stream.
+ */
+static void pme_gpu_init_internal(PmeGpu* pmeGpu, const DeviceContext& deviceContext, const DeviceStream& deviceStream)
 {
-#if GMX_GPU == GMX_GPU_CUDA
-    // Prepare to use the device that this PME task was assigned earlier.
-    // Other entities, such as CUDA timing events, are known to implicitly use the device context.
-    CU_RET_ERR(cudaSetDevice(pmeGpu->deviceInfo->id), "Switching to PME CUDA device");
-#endif
-
     /* Allocate the target-specific structures */
-    pmeGpu->archSpecific.reset(new PmeGpuSpecific());
+    pmeGpu->archSpecific.reset(new PmeGpuSpecific(deviceContext, deviceStream));
     pmeGpu->kernelParams.reset(new PmeGpuKernelParams());
 
     pmeGpu->archSpecific->performOutOfPlaceFFT = true;
@@ -526,80 +593,63 @@ void pme_gpu_init_internal(PmeGpu* pmeGpu)
      * TODO: PME could also try to pick up nice grid sizes (with factors of 2, 3, 5, 7).
      */
 
-    // TODO: this is just a convenient reuse because programHandle_ currently is in charge of creating context
-    pmeGpu->archSpecific->context = pmeGpu->programHandle_->impl_->context;
-
-    // timing enabling - TODO put this in gpu_utils (even though generally this is just option handling?) and reuse in NB
-    if (GMX_GPU == GMX_GPU_CUDA)
-    {
-        /* WARNING: CUDA timings are incorrect with multiple streams.
-         *          This is the main reason why they are disabled by default.
-         */
-        // TODO: Consider turning on by default when we can detect nr of streams.
-        pmeGpu->archSpecific->useTiming = (getenv("GMX_ENABLE_GPU_TIMING") != nullptr);
-    }
-    else if (GMX_GPU == GMX_GPU_OPENCL)
-    {
-        pmeGpu->archSpecific->useTiming = (getenv("GMX_DISABLE_GPU_TIMING") == nullptr);
-    }
-
-#if GMX_GPU == GMX_GPU_CUDA
-    pmeGpu->maxGridWidthX = pmeGpu->deviceInfo->prop.maxGridSize[0];
-#elif GMX_GPU == GMX_GPU_OPENCL
-    pmeGpu->maxGridWidthX = INT32_MAX / 2;
+#if GMX_GPU_CUDA
+    pmeGpu->kernelParams->usePipeline       = char(false);
+    pmeGpu->kernelParams->pipelineAtomStart = 0;
+    pmeGpu->kernelParams->pipelineAtomEnd   = 0;
+    pmeGpu->maxGridWidthX                   = deviceContext.deviceInfo().prop.maxGridSize[0];
+#else
+    // Use this path for any non-CUDA GPU acceleration
     // TODO: is there really no global work size limit in OpenCL?
-#endif
-
-    /* Creating a PME GPU stream:
-     * - default high priority with CUDA
-     * - no priorities implemented yet with OpenCL; see #2532
-     */
-#if GMX_GPU == GMX_GPU_CUDA
-    cudaError_t stat;
-    int         highest_priority, lowest_priority;
-    stat = cudaDeviceGetStreamPriorityRange(&lowest_priority, &highest_priority);
-    CU_RET_ERR(stat, "PME cudaDeviceGetStreamPriorityRange failed");
-    stat = cudaStreamCreateWithPriority(&pmeGpu->archSpecific->pmeStream,
-                                        cudaStreamDefault, // cudaStreamNonBlocking,
-                                        highest_priority);
-    CU_RET_ERR(stat, "cudaStreamCreateWithPriority on the PME stream failed");
-#elif GMX_GPU == GMX_GPU_OPENCL
-    cl_command_queue_properties queueProperties =
-            pmeGpu->archSpecific->useTiming ? CL_QUEUE_PROFILING_ENABLE : 0;
-    cl_device_id device_id = pmeGpu->deviceInfo->ocl_gpu_id.ocl_device_id;
-    cl_int       clError;
-    pmeGpu->archSpecific->pmeStream =
-            clCreateCommandQueue(pmeGpu->archSpecific->context, device_id, queueProperties, &clError);
-    if (clError != CL_SUCCESS)
-    {
-        GMX_THROW(gmx::InternalError("Failed to create PME command queue"));
-    }
-#endif
-}
-
-void pme_gpu_destroy_specific(const PmeGpu* pmeGpu)
-{
-#if GMX_GPU == GMX_GPU_CUDA
-    /* Destroy the CUDA stream */
-    cudaError_t stat = cudaStreamDestroy(pmeGpu->archSpecific->pmeStream);
-    CU_RET_ERR(stat, "PME cudaStreamDestroy error");
-#elif GMX_GPU == GMX_GPU_OPENCL
-    cl_int clError = clReleaseCommandQueue(pmeGpu->archSpecific->pmeStream);
-    if (clError != CL_SUCCESS)
-    {
-        gmx_warning("Failed to destroy PME command queue");
-    }
+    pmeGpu->maxGridWidthX = INT32_MAX / 2;
 #endif
 }
 
 void pme_gpu_reinit_3dfft(const PmeGpu* pmeGpu)
 {
-    if (pme_gpu_performs_FFT(pmeGpu))
+    if (pme_gpu_settings(pmeGpu).performGPUFFT)
     {
         pmeGpu->archSpecific->fftSetup.resize(0);
-        for (int i = 0; i < pmeGpu->common->ngrids; i++)
+        const bool         performOutOfPlaceFFT      = pmeGpu->archSpecific->performOutOfPlaceFFT;
+        const bool         allocateGrid              = false;
+        MPI_Comm           comm                      = MPI_COMM_NULL;
+        std::array<int, 1> gridOffsetsInXForEachRank = { 0 };
+        std::array<int, 1> gridOffsetsInYForEachRank = { 0 };
+#if GMX_GPU_CUDA
+        const gmx::FftBackend backend = gmx::FftBackend::Cufft;
+#elif GMX_GPU_OPENCL
+        const gmx::FftBackend backend = gmx::FftBackend::Ocl;
+#elif GMX_GPU_SYCL
+#    if GMX_SYCL_DPCPP && GMX_FFT_MKL
+        const gmx::FftBackend backend = gmx::FftBackend::SyclMkl;
+#    elif GMX_SYCL_HIPSYCL
+        const gmx::FftBackend backend = gmx::FftBackend::SyclRocfft;
+#    else
+        const gmx::FftBackend backend = gmx::FftBackend::Sycl;
+#    endif
+#else
+        GMX_RELEASE_ASSERT(false, "Unknown GPU backend");
+        const gmx::FftBackend backend = gmx::FftBackend::Count;
+#endif
+
+        PmeGpuGridParams& grid = pme_gpu_get_kernel_params_base_ptr(pmeGpu)->grid;
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
         {
-            pmeGpu->archSpecific->fftSetup.push_back(std::make_unique<GpuParallel3dFft>(pmeGpu));
+            pmeGpu->archSpecific->fftSetup.push_back(
+                    std::make_unique<gmx::Gpu3dFft>(backend,
+                                                    allocateGrid,
+                                                    comm,
+                                                    gridOffsetsInXForEachRank,
+                                                    gridOffsetsInYForEachRank,
+                                                    grid.realGridSize[ZZ],
+                                                    performOutOfPlaceFFT,
+                                                    pmeGpu->archSpecific->deviceContext_,
+                                                    pmeGpu->archSpecific->pmeStream_,
+                                                    grid.realGridSize,
+                                                    grid.realGridSizePadded,
+                                                    grid.complexGridSizePadded,
+                                                    &(grid.d_realGrid[gridIndex]),
+                                                    &(grid.d_fourierGrid[gridIndex])));
         }
     }
 }
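// The MPI_COMM_NULL communicator and the single-entry offset arrays reflect that this GPU
// FFT path is only taken on a single PME rank (performGPUFFT is enabled only when PME
// decomposition is not in use, see pme_gpu_reinit below), so no distribution of the grid
// over ranks is set up here.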
@@ -609,69 +659,62 @@ void pme_gpu_destroy_3dfft(const PmeGpu* pmeGpu)
     pmeGpu->archSpecific->fftSetup.resize(0);
 }
 
-int getSplineParamFullIndex(int order, int splineIndex, int dimIndex, int atomIndex, int atomsPerWarp)
+void pme_gpu_getEnergyAndVirial(const gmx_pme_t& pme, const float lambda, PmeOutput* output)
 {
-    if (order != c_pmeGpuOrder)
+    const PmeGpu* pmeGpu = pme.gpu;
+
+    GMX_ASSERT(lambda == 1.0 || pmeGpu->common->ngrids == 2,
+               "Invalid combination of lambda and number of grids");
+
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        throw order;
+        for (int j = 0; j < c_virialAndEnergyCount; j++)
+        {
+            GMX_ASSERT(std::isfinite(pmeGpu->staging.h_virialAndEnergy[gridIndex][j]),
+                       "PME GPU produces incorrect energy/virial.");
+        }
     }
-    constexpr int fixedOrder = c_pmeGpuOrder;
-    GMX_UNUSED_VALUE(fixedOrder);
-
-    const int atomWarpIndex = atomIndex % atomsPerWarp;
-    const int warpIndex     = atomIndex / atomsPerWarp;
-    int       indexBase, result;
-    switch (atomsPerWarp)
+    for (int dim1 = 0; dim1 < DIM; dim1++)
     {
-        case 1:
-            indexBase = getSplineParamIndexBase<fixedOrder, 1>(warpIndex, atomWarpIndex);
-            result    = getSplineParamIndex<fixedOrder, 1>(indexBase, dimIndex, splineIndex);
-            break;
-
-        case 2:
-            indexBase = getSplineParamIndexBase<fixedOrder, 2>(warpIndex, atomWarpIndex);
-            result    = getSplineParamIndex<fixedOrder, 2>(indexBase, dimIndex, splineIndex);
-            break;
-
-        case 4:
-            indexBase = getSplineParamIndexBase<fixedOrder, 4>(warpIndex, atomWarpIndex);
-            result    = getSplineParamIndex<fixedOrder, 4>(indexBase, dimIndex, splineIndex);
-            break;
-
-        case 8:
-            indexBase = getSplineParamIndexBase<fixedOrder, 8>(warpIndex, atomWarpIndex);
-            result    = getSplineParamIndex<fixedOrder, 8>(indexBase, dimIndex, splineIndex);
-            break;
-
-        default:
-            GMX_THROW(gmx::NotImplementedError(
-                    gmx::formatString("Test function call not unrolled for atomsPerWarp = %d in "
-                                      "getSplineParamFullIndex",
-                                      atomsPerWarp)));
+        for (int dim2 = 0; dim2 < DIM; dim2++)
+        {
+            output->coulombVirial_[dim1][dim2] = 0;
+        }
+    }
+    output->coulombEnergy_ = 0;
+    float scale            = 1.0;
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        if (pmeGpu->common->ngrids == 2)
+        {
+            scale = gridIndex == 0 ? (1.0 - lambda) : lambda;
+        }
+        output->coulombVirial_[XX][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][0];
+        output->coulombVirial_[YY][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][1];
+        output->coulombVirial_[ZZ][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][2];
+        output->coulombVirial_[XX][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][3];
+        output->coulombVirial_[YY][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][3];
+        output->coulombVirial_[XX][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][4];
+        output->coulombVirial_[ZZ][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][4];
+        output->coulombVirial_[YY][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][5];
+        output->coulombVirial_[ZZ][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][5];
+        output->coulombEnergy_ += scale * 0.5F * pmeGpu->staging.h_virialAndEnergy[gridIndex][6];
+    }
+    if (pmeGpu->common->ngrids > 1)
+    {
+        output->coulombDvdl_ = 0.5F
+                               * (pmeGpu->staging.h_virialAndEnergy[FEP_STATE_B][6]
+                                  - pmeGpu->staging.h_virialAndEnergy[FEP_STATE_A][6]);
     }
-    return result;
-}
-
-void pme_gpu_getEnergyAndVirial(const gmx_pme_t& pme, PmeOutput* output)
-{
-    const PmeGpu* pmeGpu = pme.gpu;
-    for (int j = 0; j < c_virialAndEnergyCount; j++)
-    {
-        GMX_ASSERT(std::isfinite(pmeGpu->staging.h_virialAndEnergy[j]),
-                   "PME GPU produces incorrect energy/virial.");
-    }
-
-    unsigned int j                 = 0;
-    output->coulombVirial_[XX][XX] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[YY][YY] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[ZZ][ZZ] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[XX][YY] = output->coulombVirial_[YY][XX] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[XX][ZZ] = output->coulombVirial_[ZZ][XX] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[YY][ZZ] = output->coulombVirial_[ZZ][YY] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombEnergy_ = 0.5F * pmeGpu->staging.h_virialAndEnergy[j++];
 }
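// With two grids (FEP), the reduced energy and virial are interpolated linearly in lambda,
// and the derivative follows from that linear form:
//
//     E(lambda)  = (1 - lambda) * E_A + lambda * E_B
//     dE/dlambda = E_B - E_A
//
// which is what the per-grid 'scale' factor and the coulombDvdl_ term above implement
// (the 0.5F prefactor is the same normalization applied to each grid's raw accumulated energy).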
 
 /*! \brief Sets the force-related members in \p output
@@ -688,22 +731,19 @@ static void pme_gpu_getForceOutput(PmeGpu* pmeGpu, PmeOutput* output)
     }
 }
 
-PmeOutput pme_gpu_getOutput(const gmx_pme_t& pme, const int flags)
+PmeOutput pme_gpu_getOutput(const gmx_pme_t& pme, const bool computeEnergyAndVirial, const real lambdaQ)
 {
-    PmeGpu*    pmeGpu                      = pme.gpu;
-    const bool haveComputedEnergyAndVirial = (flags & GMX_PME_CALC_ENER_VIR) != 0;
+    PmeGpu* pmeGpu = pme.gpu;
 
     PmeOutput output;
 
     pme_gpu_getForceOutput(pmeGpu, &output);
 
-    // The caller knows from the flags that the energy and the virial are not usable
-    // on the else branch
-    if (haveComputedEnergyAndVirial)
+    if (computeEnergyAndVirial)
     {
-        if (pme_gpu_performs_solve(pmeGpu))
+        if (pme_gpu_settings(pmeGpu).performGPUSolve)
         {
-            pme_gpu_getEnergyAndVirial(pme, &output);
+            pme_gpu_getEnergyAndVirial(pme, lambdaQ, &output);
         }
         else
         {
@@ -745,9 +785,13 @@ void pme_gpu_update_input_box(PmeGpu gmx_unused* pmeGpu, const matrix gmx_unused
 static void pme_gpu_reinit_grids(PmeGpu* pmeGpu)
 {
     auto* kernelParamsPtr = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
+
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
     kernelParamsPtr->grid.ewaldFactor =
             (M_PI * M_PI) / (pmeGpu->common->ewaldcoeff_q * pmeGpu->common->ewaldcoeff_q);
-
     /* The grid size variants */
     for (int i = 0; i < DIM; i++)
     {
@@ -762,20 +806,23 @@ static void pme_gpu_reinit_grids(PmeGpu* pmeGpu)
         kernelParamsPtr->grid.complexGridSizePadded[i] = kernelParamsPtr->grid.realGridSize[i];
     }
     /* FFT: n real elements correspond to (n / 2 + 1) complex elements in minor dimension */
-    if (!pme_gpu_performs_FFT(pmeGpu))
+    if (!pme_gpu_settings(pmeGpu).performGPUFFT)
     {
         // This allows for GPU spreading grid and CPU fftgrid to have the same layout, so that we can copy the data directly
         kernelParamsPtr->grid.realGridSizePadded[ZZ] =
                 (kernelParamsPtr->grid.realGridSize[ZZ] / 2 + 1) * 2;
     }
-
     /* GPU FFT: n real elements correspond to (n / 2 + 1) complex elements in minor dimension */
     kernelParamsPtr->grid.complexGridSize[ZZ] /= 2;
     kernelParamsPtr->grid.complexGridSize[ZZ]++;
     kernelParamsPtr->grid.complexGridSizePadded[ZZ] = kernelParamsPtr->grid.complexGridSize[ZZ];
 
     pme_gpu_realloc_and_copy_fract_shifts(pmeGpu);
-    pme_gpu_realloc_and_copy_bspline_values(pmeGpu);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        pme_gpu_realloc_and_copy_bspline_values(pmeGpu, gridIndex);
+    }
+
     pme_gpu_realloc_grids(pmeGpu);
     pme_gpu_reinit_3dfft(pmeGpu);
 }
@@ -796,7 +843,7 @@ static void pme_gpu_copy_common_data_from(const gmx_pme_t* pme)
     /* TODO: Consider refactoring the CPU PME code to use the same structure,
      * so that this function becomes 2 lines */
     PmeGpu* pmeGpu               = pme->gpu;
-    pmeGpu->common->ngrids       = pme->ngrids;
+    pmeGpu->common->ngrids       = pme->bFEP_q ? 2 : 1;
     pmeGpu->common->epsilon_r    = pme->epsilon_r;
     pmeGpu->common->ewaldcoeff_q = pme->ewaldcoeff_q;
     pmeGpu->common->nk[XX]       = pme->nkx;
@@ -822,18 +869,42 @@ static void pme_gpu_copy_common_data_from(const gmx_pme_t* pme)
     pmeGpu->common->nn.insert(pmeGpu->common->nn.end(), pme->nnz, pme->nnz + cellCount * pme->nkz);
     pmeGpu->common->runMode       = pme->runMode;
     pmeGpu->common->isRankPmeOnly = !pme->bPPnode;
-    pmeGpu->common->boxScaler     = pme->boxScaler;
+    pmeGpu->common->boxScaler     = pme->boxScaler.get();
+}
+
+/*! \libinternal \brief
+ * Uses heuristics to select the best-performing PME spread and gather kernels.
+ *
+ * \param[in,out] pmeGpu         The PME GPU structure.
+ */
+static void pme_gpu_select_best_performing_pme_spreadgather_kernels(PmeGpu* pmeGpu)
+{
+    if (GMX_GPU_CUDA && pmeGpu->kernelParams->atoms.nAtoms > c_pmeGpuPerformanceAtomLimit)
+    {
+        pmeGpu->settings.threadsPerAtom     = ThreadsPerAtom::Order;
+        pmeGpu->settings.recalculateSplines = true;
+    }
+    else
+    {
+        pmeGpu->settings.threadsPerAtom     = ThreadsPerAtom::OrderSquared;
+        pmeGpu->settings.recalculateSplines = false;
+    }
 }
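// Design note: below c_pmeGpuPerformanceAtomLimit the spline parameters are computed once
// in spread with order*order threads per atom and reloaded in gather; above it, spread and
// gather use order threads per atom and gather recalculates the splines, which for large
// systems tends to be cheaper than storing and reloading them. For pme_order == 4 this means:
//
//     ThreadsPerAtom::Order        ->  4 threads per atom, splines recalculated in gather
//     ThreadsPerAtom::OrderSquared -> 16 threads per atom, splines stored by spread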
 
+
 /*! \libinternal \brief
  * Initializes the PME GPU data at the beginning of the run.
  * TODO: this should become PmeGpu::PmeGpu()
  *
  * \param[in,out] pme            The PME structure.
- * \param[in,out] gpuInfo        The GPU information structure.
- * \param[in]     pmeGpuProgram  The handle to the program/kernel data created outside (e.g. in unit tests/runner)
+ * \param[in]     deviceContext  The GPU context.
+ * \param[in]     deviceStream   The GPU stream.
+ * \param[in,out] pmeGpuProgram  The handle to the program/kernel data created outside (e.g. in unit tests/runner)
  */
-static void pme_gpu_init(gmx_pme_t* pme, const gmx_device_info_t* gpuInfo, PmeGpuProgramHandle pmeGpuProgram)
+static void pme_gpu_init(gmx_pme_t*           pme,
+                         const DeviceContext& deviceContext,
+                         const DeviceStream&  deviceStream,
+                         const PmeGpuProgram* pmeGpuProgram)
 {
     pme->gpu       = new PmeGpu();
     PmeGpu* pmeGpu = pme->gpu;
@@ -842,7 +913,7 @@ static void pme_gpu_init(gmx_pme_t* pme, const gmx_device_info_t* gpuInfo, PmeGp
 
     /* These settings are set here for the whole run; dynamic ones are set in pme_gpu_reinit() */
     /* A convenience variable. */
-    pmeGpu->settings.useDecomposition = (pme->nnodes == 1);
+    pmeGpu->settings.useDecomposition = (pme->nnodes != 1);
     /* TODO: CPU gather with GPU spread is broken due to different theta/dtheta layout. */
     pmeGpu->settings.performGPUGather = true;
     // By default GPU-side reduction is off (explicitly set here for tests, otherwise reset per-step)
@@ -850,82 +921,20 @@ static void pme_gpu_init(gmx_pme_t* pme, const gmx_device_info_t* gpuInfo, PmeGp
 
     pme_gpu_set_testing(pmeGpu, false);
 
-    pmeGpu->deviceInfo = gpuInfo;
     GMX_ASSERT(pmeGpuProgram != nullptr, "GPU kernels must be already compiled");
     pmeGpu->programHandle_ = pmeGpuProgram;
 
     pmeGpu->initializedClfftLibrary_ = std::make_unique<gmx::ClfftInitializer>();
 
-    pme_gpu_init_internal(pmeGpu);
-    pme_gpu_alloc_energy_virial(pmeGpu);
+    pme_gpu_init_internal(pmeGpu, deviceContext, deviceStream);
 
     pme_gpu_copy_common_data_from(pme);
+    pme_gpu_alloc_energy_virial(pmeGpu);
 
     GMX_ASSERT(pmeGpu->common->epsilon_r != 0.0F, "PME GPU: bad electrostatic coefficient");
 
     auto* kernelParamsPtr               = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
-    kernelParamsPtr->constants.elFactor = ONE_4PI_EPS0 / pmeGpu->common->epsilon_r;
-}
-
-void pme_gpu_transform_spline_atom_data(const PmeGpu*      pmeGpu,
-                                        const PmeAtomComm* atc,
-                                        PmeSplineDataType  type,
-                                        int                dimIndex,
-                                        PmeLayoutTransform transform)
-{
-    // The GPU atom spline data is laid out in a different way currently than the CPU one.
-    // This function converts the data from GPU to CPU layout (in the host memory).
-    // It is only intended for testing purposes so far.
-    // Ideally we should use similar layouts on CPU and GPU if we care about mixed modes and their
-    // performance (e.g. spreading on GPU, gathering on CPU).
-    GMX_RELEASE_ASSERT(atc->nthread == 1, "Only the serial PME data layout is supported");
-    const uintmax_t threadIndex  = 0;
-    const auto      atomCount    = pme_gpu_get_kernel_params_base_ptr(pmeGpu)->atoms.nAtoms;
-    const auto      atomsPerWarp = pme_gpu_get_atoms_per_warp(pmeGpu);
-    const auto      pmeOrder     = pmeGpu->common->pme_order;
-    GMX_ASSERT(pmeOrder == c_pmeGpuOrder, "Only PME order 4 is implemented");
-
-    real*  cpuSplineBuffer;
-    float* h_splineBuffer;
-    switch (type)
-    {
-        case PmeSplineDataType::Values:
-            cpuSplineBuffer = atc->spline[threadIndex].theta.coefficients[dimIndex];
-            h_splineBuffer  = pmeGpu->staging.h_theta;
-            break;
-
-        case PmeSplineDataType::Derivatives:
-            cpuSplineBuffer = atc->spline[threadIndex].dtheta.coefficients[dimIndex];
-            h_splineBuffer  = pmeGpu->staging.h_dtheta;
-            break;
-
-        default: GMX_THROW(gmx::InternalError("Unknown spline data type"));
-    }
-
-    for (auto atomIndex = 0; atomIndex < atomCount; atomIndex++)
-    {
-        for (auto orderIndex = 0; orderIndex < pmeOrder; orderIndex++)
-        {
-            const auto gpuValueIndex =
-                    getSplineParamFullIndex(pmeOrder, orderIndex, dimIndex, atomIndex, atomsPerWarp);
-            const auto cpuValueIndex = atomIndex * pmeOrder + orderIndex;
-            GMX_ASSERT(cpuValueIndex < atomCount * pmeOrder,
-                       "Atom spline data index out of bounds (while transforming GPU data layout "
-                       "for host)");
-            switch (transform)
-            {
-                case PmeLayoutTransform::GpuToHost:
-                    cpuSplineBuffer[cpuValueIndex] = h_splineBuffer[gpuValueIndex];
-                    break;
-
-                case PmeLayoutTransform::HostToGpu:
-                    h_splineBuffer[gpuValueIndex] = cpuSplineBuffer[cpuValueIndex];
-                    break;
-
-                default: GMX_THROW(gmx::InternalError("Unknown layout transform"));
-            }
-        }
-    }
+    kernelParamsPtr->constants.elFactor = gmx::c_one4PiEps0 / pmeGpu->common->epsilon_r;
 }
 
 void pme_gpu_get_real_grid_sizes(const PmeGpu* pmeGpu, gmx::IVec* gridSize, gmx::IVec* paddedGridSize)
@@ -941,17 +950,21 @@ void pme_gpu_get_real_grid_sizes(const PmeGpu* pmeGpu, gmx::IVec* gridSize, gmx:
     }
 }
 
-void pme_gpu_reinit(gmx_pme_t* pme, const gmx_device_info_t* gpuInfo, PmeGpuProgramHandle pmeGpuProgram)
+void pme_gpu_reinit(gmx_pme_t*           pme,
+                    const DeviceContext* deviceContext,
+                    const DeviceStream*  deviceStream,
+                    const PmeGpuProgram* pmeGpuProgram)
 {
-    if (!pme_gpu_active(pme))
-    {
-        return;
-    }
+    GMX_ASSERT(pme != nullptr, "Need valid PME object");
 
     if (!pme->gpu)
     {
+        GMX_RELEASE_ASSERT(deviceContext != nullptr,
+                           "Device context can not be nullptr when setting up PME on GPU.");
+        GMX_RELEASE_ASSERT(deviceStream != nullptr,
+                           "Device stream can not be nullptr when setting up PME on GPU.");
         /* First-time initialization */
-        pme_gpu_init(pme, gpuInfo, pmeGpuProgram);
+        pme_gpu_init(pme, *deviceContext, *deviceStream, pmeGpuProgram);
     }
     else
     {
@@ -960,7 +973,7 @@ void pme_gpu_reinit(gmx_pme_t* pme, const gmx_device_info_t* gpuInfo, PmeGpuProg
     }
     /* GPU FFT will only get used for a single rank. */
     pme->gpu->settings.performGPUFFT =
-            (pme->gpu->common->runMode == PmeRunMode::GPU) && !pme_gpu_uses_dd(pme->gpu);
+            (pme->gpu->common->runMode == PmeRunMode::GPU) && !pme->gpu->settings.useDecomposition;
     pme->gpu->settings.performGPUSolve = (pme->gpu->common->runMode == PmeRunMode::GPU);
 
     /* Reinit active timers */
@@ -990,29 +1003,41 @@ void pme_gpu_destroy(PmeGpu* pmeGpu)
 
     pme_gpu_destroy_3dfft(pmeGpu);
 
-    /* Free the GPU-framework specific data last */
-    pme_gpu_destroy_specific(pmeGpu);
-
     delete pmeGpu;
 }
 
-void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* charges)
+void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* chargesA, const real* chargesB)
 {
     auto* kernelParamsPtr         = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
     kernelParamsPtr->atoms.nAtoms = nAtoms;
-    const int alignment           = pme_gpu_get_atom_data_alignment(pmeGpu);
-    pmeGpu->nAtomsPadded          = ((nAtoms + alignment - 1) / alignment) * alignment;
-    const int  nAtomsAlloc        = c_usePadding ? pmeGpu->nAtomsPadded : nAtoms;
-    const bool haveToRealloc =
-            (pmeGpu->nAtomsAlloc < nAtomsAlloc); /* This check might be redundant, but is logical */
-    pmeGpu->nAtomsAlloc = nAtomsAlloc;
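+    // The allocation size is rounded up to a whole number of atom data blocks, so that the
+    // spread/gather kernels can always be launched with full blocks of atoms.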
+    const int  block_size         = pme_gpu_get_atom_data_block_size();
+    const int  nAtomsNewPadded    = ((nAtoms + block_size - 1) / block_size) * block_size;
+    const bool haveToRealloc      = (pmeGpu->nAtomsAlloc < nAtomsNewPadded);
+    pmeGpu->nAtomsAlloc           = nAtomsNewPadded;
 
 #if GMX_DOUBLE
     GMX_RELEASE_ASSERT(false, "Only single precision supported");
-    GMX_UNUSED_VALUE(charges);
+    GMX_UNUSED_VALUE(chargesA);
+    GMX_UNUSED_VALUE(chargesB);
 #else
-    pme_gpu_realloc_and_copy_input_coefficients(pmeGpu, reinterpret_cast<const float*>(charges));
+    int gridIndex = 0;
     /* This could be made conditional on haveToRealloc, but the copy always needs to be performed */
+    pme_gpu_realloc_and_copy_input_coefficients(pmeGpu, reinterpret_cast<const float*>(chargesA), gridIndex);
+    gridIndex++;
+    if (chargesB != nullptr)
+    {
+        pme_gpu_realloc_and_copy_input_coefficients(
+                pmeGpu, reinterpret_cast<const float*>(chargesB), gridIndex);
+    }
+    else
+    {
+        /* Fill the second set of coefficients with chargesA as well to be able to avoid
+         * conditionals in the GPU kernels */
+        /* FIXME: This should be avoided by making a separate templated version of the
+         * relevant kernel(s) (probably only pme_gather_kernel). That would require a
+         * reduction of the current number of templated parameters of that kernel. */
+        pme_gpu_realloc_and_copy_input_coefficients(
+                pmeGpu, reinterpret_cast<const float*>(chargesA), gridIndex);
+    }
 #endif
 
     if (haveToRealloc)
@@ -1021,11 +1046,31 @@ void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* charges)
         pme_gpu_realloc_spline_data(pmeGpu);
         pme_gpu_realloc_grid_indices(pmeGpu);
     }
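+    // Re-select the best-performing spread/gather kernel variants for the (possibly changed)
+    // atom count.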
+    pme_gpu_select_best_performing_pme_spreadgather_kernels(pmeGpu);
 }
 
-void pme_gpu_3dfft(const PmeGpu* pmeGpu, gmx_fft_direction dir, int grid_index)
+/*! \internal \brief
+ * Returns the raw timing event from the corresponding GpuRegionTimer (if timings are enabled).
+ * In CUDA the result can be a nullptr stub, per the GpuRegionTimer implementation.
+ *
+ * \param[in] pmeGpu         The PME GPU data structure.
+ * \param[in] pmeStageId     The PME GPU stage index from the PmeStage enum in src/gromacs/timing/gpu_timing.h
+ */
+static CommandEvent* pme_gpu_fetch_timing_event(const PmeGpu* pmeGpu, PmeStage pmeStageId)
 {
-    int timerId = (dir == GMX_FFT_REAL_TO_COMPLEX) ? gtPME_FFT_R2C : gtPME_FFT_C2R;
+    CommandEvent* timingEvent = nullptr;
+    if (pme_gpu_timings_enabled(pmeGpu))
+    {
+        GMX_ASSERT(pmeStageId < PmeStage::Count, "Wrong PME GPU timing event index");
+        timingEvent = pmeGpu->archSpecific->timingEvents[pmeStageId].fetchNextEvent();
+    }
+    return timingEvent;
+}
+
+void pme_gpu_3dfft(const PmeGpu* pmeGpu, gmx_fft_direction dir, const int grid_index)
+{
+    PmeStage timerId = (dir == GMX_FFT_REAL_TO_COMPLEX) ? PmeStage::FftTransformR2C
+                                                        : PmeStage::FftTransformC2R;
 
     pme_gpu_start_timing(pmeGpu, timerId);
     pmeGpu->archSpecific->fftSetup[grid_index]->perform3dFft(
@@ -1053,34 +1098,66 @@ std::pair<int, int> inline pmeGpuCreateGrid(const PmeGpu* pmeGpu, int blockCount
  * Returns a pointer to the appropriate spline and spread kernel based on the input values
  *
  * \param[in]  pmeGpu                   The PME GPU structure.
- * \param[in]  useOrderThreadsPerAtom   bool controlling if we should use order or order*order threads per atom
+ * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSplineAndSpreadKernelPtr(const PmeGpu* pmeGpu, bool useOrderThreadsPerAtom, bool writeSplinesToGlobal)
+static auto selectSplineAndSpreadKernelPtr(const PmeGpu*  pmeGpu,
+                                           ThreadsPerAtom threadsPerAtom,
+                                           bool           writeSplinesToGlobal,
+                                           const int      numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (writeSplinesToGlobal)
     {
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplines;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesSingle;
+            }
         }
     }
     else
     {
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelSingle;
+            }
         }
     }
 
@@ -1091,25 +1168,43 @@ static auto selectSplineAndSpreadKernelPtr(const PmeGpu* pmeGpu, bool useOrderTh
  * Returns a pointer to the appropriate spline kernel based on the input values
  *
  * \param[in]  pmeGpu                   The PME GPU structure.
- * \param[in]  useOrderThreadsPerAtom   bool controlling if we should use order or order*order threads per atom
+ * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSplineKernelPtr(const PmeGpu* pmeGpu, bool useOrderThreadsPerAtom, bool gmx_unused writeSplinesToGlobal)
+static auto selectSplineKernelPtr(const PmeGpu*   pmeGpu,
+                                  ThreadsPerAtom  threadsPerAtom,
+                                  bool gmx_unused writeSplinesToGlobal,
+                                  const int       numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     GMX_ASSERT(
             writeSplinesToGlobal,
             "Spline data should always be written to global memory when just calculating splines");
 
-    if (useOrderThreadsPerAtom)
+    if (threadsPerAtom == ThreadsPerAtom::Order)
     {
-        kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4;
+        if (numGrids == 2)
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4Dual;
+        }
+        else
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4Single;
+        }
     }
     else
     {
-        kernelPtr = pmeGpu->programHandle_->impl_->splineKernel;
+        if (numGrids == 2)
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelDual;
+        }
+        else
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelSingle;
+        }
     }
     return kernelPtr;
 }
@@ -1118,51 +1213,89 @@ static auto selectSplineKernelPtr(const PmeGpu* pmeGpu, bool useOrderThreadsPerA
  * Returns a pointer to the appropriate spread kernel based on the input values
  *
  * \param[in]  pmeGpu                   The PME GPU structure.
- * \param[in]  useOrderThreadsPerAtom   bool controlling if we should use order or order*order threads per atom
+ * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSpreadKernelPtr(const PmeGpu* pmeGpu, bool useOrderThreadsPerAtom, bool writeSplinesToGlobal)
+static auto selectSpreadKernelPtr(const PmeGpu*  pmeGpu,
+                                  ThreadsPerAtom threadsPerAtom,
+                                  bool           writeSplinesToGlobal,
+                                  const int      numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (writeSplinesToGlobal)
     {
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->spreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelSingle;
+            }
         }
     }
     else
     {
         /* If we are not saving the spline data, we need to recalculate it
            using the spline and spread kernel */
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelSingle;
+            }
         }
     }
     return kernelPtr;
 }
 
-void pme_gpu_spread(const PmeGpu*         pmeGpu,
-                    GpuEventSynchronizer* xReadyOnDevice,
-                    int gmx_unused gridIndex,
-                    real*          h_grid,
-                    bool           computeSplines,
-                    bool           spreadCharges)
+void pme_gpu_spread(const PmeGpu*                  pmeGpu,
+                    GpuEventSynchronizer*          xReadyOnDevice,
+                    real**                         h_grids,
+                    bool                           computeSplines,
+                    bool                           spreadCharges,
+                    const real                     lambda,
+                    const bool                     useGpuDirectComm,
+                    gmx::PmeCoordinateReceiverGpu* pmeCoordinateReceiverGpu)
 {
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
     GMX_ASSERT(computeSplines || spreadCharges,
                "PME spline/spread kernel has invalid input (nothing to do)");
-    const auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
     GMX_ASSERT(kernelParamsPtr->atoms.nAtoms > 0, "No atom data in PME GPU spread");
 
     const size_t blockSize = pmeGpu->programHandle_->impl_->spreadWorkGroupSize;
@@ -1170,12 +1303,16 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     const int order = pmeGpu->common->pme_order;
     GMX_ASSERT(order == c_pmeGpuOrder, "Only PME order 4 is implemented");
     const bool writeGlobal = pmeGpu->settings.copyAllOutputs;
-#if GMX_GPU == GMX_GPU_OPENCL
-    GMX_ASSERT(!c_useOrderThreadsPerAtom, "Only 16 threads per atom supported in OpenCL");
-    GMX_ASSERT(!c_recalculateSplines, "Recalculating splines not supported in OpenCL");
-#endif
-    const int atomsPerBlock = c_useOrderThreadsPerAtom ? blockSize / c_pmeSpreadGatherThreadsPerAtom4ThPerAtom
-                                                       : blockSize / c_pmeSpreadGatherThreadsPerAtom;
+    const int  threadsPerAtom =
+            (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? order : order * order);
+    const bool recalculateSplines = pmeGpu->settings.recalculateSplines;
+
+    GMX_ASSERT(!GMX_GPU_OPENCL || pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::OrderSquared,
+               "Only 16 threads per atom supported in OpenCL");
+    GMX_ASSERT(!GMX_GPU_OPENCL || !recalculateSplines,
+               "Recalculating splines not supported in OpenCL");
+
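+    // With PME order 4, threadsPerAtom is either 4 (ThreadsPerAtom::Order) or
+    // 16 (ThreadsPerAtom::OrderSquared), so each block handles blockSize/4 or blockSize/16 atoms.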
+    const int atomsPerBlock = blockSize / threadsPerAtom;
 
     // TODO: pick smaller block size in runtime if needed
     // (e.g. on 660 Ti where 50% occupancy is ~25% faster than 100% occupancy with RNAse (~17.8k atoms))
@@ -1183,98 +1320,215 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     // TODO: test varying block sizes on modern arch-s as well
     // TODO: also consider using cudaFuncSetCacheConfig() for preferring shared memory on older architectures
     // (for spline data mostly)
-    GMX_ASSERT(!c_usePadding || !(c_pmeAtomDataAlignment % atomsPerBlock),
+    GMX_ASSERT(!(c_pmeAtomDataBlockSize % atomsPerBlock),
                "inconsistent atom data padding vs. spreading block size");
 
     // Ensure that coordinates are ready on the device before launching spread;
-    // only needed with CUDA on PP+PME ranks, not on separate PME ranks, in unit tests
-    // nor in OpenCL as these cases use a single stream (hence xReadyOnDevice == nullptr).
-    GMX_ASSERT(xReadyOnDevice != nullptr || (GMX_GPU != GMX_GPU_CUDA)
-                       || pmeGpu->common->isRankPmeOnly || pme_gpu_is_testing(pmeGpu),
+    // only needed on PP+PME ranks, not on separate PME ranks nor in unit tests,
+    // as these cases use a single stream (hence xReadyOnDevice == nullptr).
+    GMX_ASSERT(xReadyOnDevice != nullptr || pmeGpu->common->isRankPmeOnly
+                       || pme_gpu_settings(pmeGpu).copyAllOutputs,
                "Need a valid coordinate synchronizer on PP+PME ranks with CUDA.");
+
     if (xReadyOnDevice)
     {
-        xReadyOnDevice->enqueueWaitEvent(pmeGpu->archSpecific->pmeStream);
+        xReadyOnDevice->enqueueWaitEvent(pmeGpu->archSpecific->pmeStream_);
     }
 
-    const int blockCount = pmeGpu->nAtomsPadded / atomsPerBlock;
+    const int blockCount = pmeGpu->nAtomsAlloc / atomsPerBlock;
     auto      dimGrid    = pmeGpuCreateGrid(pmeGpu, blockCount);
 
+    if (pmeGpu->common->ngrids == 1)
+    {
+        kernelParamsPtr->current.scale = 1.0;
+    }
+    else
+    {
+        kernelParamsPtr->current.scale = 1.0 - lambda;
+    }
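+    // current.scale is the weight of the A-state contribution; with two grids (FEP) the B state
+    // gets the complementary weight (1 - scale = lambda).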
+
     KernelLaunchConfig config;
     config.blockSize[0] = order;
-    config.blockSize[1] = c_useOrderThreadsPerAtom ? 1 : order;
+    config.blockSize[1] = (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? 1 : order);
     config.blockSize[2] = atomsPerBlock;
     config.gridSize[0]  = dimGrid.first;
     config.gridSize[1]  = dimGrid.second;
-    config.stream       = pmeGpu->archSpecific->pmeStream;
 
-    int                                timingId;
+    PmeStage                           timingId;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
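+    // Spline data has to end up in global memory either when all outputs are copied (tests)
+    // or when the gather will reload it instead of recalculating it.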
+    const bool writeGlobalOrSaveSplines          = writeGlobal || (!recalculateSplines);
     if (computeSplines)
     {
         if (spreadCharges)
         {
-            timingId  = gtPME_SPLINEANDSPREAD;
-            kernelPtr = selectSplineAndSpreadKernelPtr(pmeGpu, c_useOrderThreadsPerAtom,
-                                                       writeGlobal || (!c_recalculateSplines));
+            timingId  = PmeStage::SplineAndSpread;
+            kernelPtr = selectSplineAndSpreadKernelPtr(pmeGpu,
+                                                       pmeGpu->settings.threadsPerAtom,
+                                                       writeGlobalOrSaveSplines,
+                                                       pmeGpu->common->ngrids);
         }
         else
         {
-            timingId  = gtPME_SPLINE;
-            kernelPtr = selectSplineKernelPtr(pmeGpu, c_useOrderThreadsPerAtom,
-                                              writeGlobal || (!c_recalculateSplines));
+            timingId  = PmeStage::Spline;
+            kernelPtr = selectSplineKernelPtr(pmeGpu,
+                                              pmeGpu->settings.threadsPerAtom,
+                                              writeGlobalOrSaveSplines,
+                                              pmeGpu->common->ngrids);
         }
     }
     else
     {
-        timingId  = gtPME_SPREAD;
-        kernelPtr = selectSpreadKernelPtr(pmeGpu, c_useOrderThreadsPerAtom,
-                                          writeGlobal || (!c_recalculateSplines));
+        timingId  = PmeStage::Spread;
+        kernelPtr = selectSpreadKernelPtr(
+                pmeGpu, pmeGpu->settings.threadsPerAtom, writeGlobalOrSaveSplines, pmeGpu->common->ngrids);
     }
 
 
     pme_gpu_start_timing(pmeGpu, timingId);
     auto* timingEvent = pme_gpu_fetch_timing_event(pmeGpu, timingId);
+
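+    // Pipelining overlaps PME-PP coordinate communication with the spread: when coordinates
+    // arrive via GPU-direct communication from more than one PP rank, each sender's atom range
+    // is spread in that sender's communication stream as soon as its coordinates are ready.
+    // This requires the combined spline-and-spread kernel and no spline output to global memory.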
+    kernelParamsPtr->usePipeline = char(computeSplines && spreadCharges && useGpuDirectComm
+                                        && (pmeCoordinateReceiverGpu->ppCommNumSenderRanks() > 1)
+                                        && !writeGlobalOrSaveSplines);
+    if (kernelParamsPtr->usePipeline != 0)
+    {
+        int numStagesInPipeline = pmeCoordinateReceiverGpu->ppCommNumSenderRanks();
+
+        for (int i = 0; i < numStagesInPipeline; i++)
+        {
+            int senderRank;
+            if (useGpuDirectComm)
+            {
+                senderRank = pmeCoordinateReceiverGpu->synchronizeOnCoordinatesFromPpRank(
+                        i, *(pmeCoordinateReceiverGpu->ppCommStream(i)));
+            }
+            else
+            {
+                senderRank = i;
+            }
+
+            // set kernel configuration options specific to this stage of the pipeline
+            std::tie(kernelParamsPtr->pipelineAtomStart, kernelParamsPtr->pipelineAtomEnd) =
+                    pmeCoordinateReceiverGpu->ppCommAtomRange(senderRank);
+            const int blockCount       = static_cast<int>(std::ceil(
+                    static_cast<float>(kernelParamsPtr->pipelineAtomEnd - kernelParamsPtr->pipelineAtomStart)
+                    / atomsPerBlock));
+            auto      dimGrid          = pmeGpuCreateGrid(pmeGpu, blockCount);
+            config.gridSize[0]         = dimGrid.first;
+            config.gridSize[1]         = dimGrid.second;
+            DeviceStream* launchStream = pmeCoordinateReceiverGpu->ppCommStream(senderRank);
+
+
 #if c_canEmbedBuffers
-    const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
+            const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->atoms.d_theta,
-            &kernelParamsPtr->atoms.d_dtheta, &kernelParamsPtr->atoms.d_gridlineIndices,
-            &kernelParamsPtr->grid.d_realGrid, &kernelParamsPtr->grid.d_fractShiftsTable,
-            &kernelParamsPtr->grid.d_gridlineIndicesTable, &kernelParamsPtr->atoms.d_coefficients,
-            &kernelParamsPtr->atoms.d_coordinates);
+            const auto kernelArgs =
+                    prepareGpuKernelArguments(kernelPtr,
+                                              config,
+                                              kernelParamsPtr,
+                                              &kernelParamsPtr->atoms.d_theta,
+                                              &kernelParamsPtr->atoms.d_dtheta,
+                                              &kernelParamsPtr->atoms.d_gridlineIndices,
+                                              &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A],
+                                              &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+                                              &kernelParamsPtr->grid.d_fractShiftsTable,
+                                              &kernelParamsPtr->grid.d_gridlineIndicesTable,
+                                              &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+                                              &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+                                              &kernelParamsPtr->atoms.d_coordinates);
 #endif
 
-    launchGpuKernel(kernelPtr, config, timingEvent, "PME spline/spread", kernelArgs);
+            launchGpuKernel(kernelPtr, config, *launchStream, timingEvent, "PME spline/spread", kernelArgs);
+        }
+        // Set dependencies for PME stream on all pipeline streams
+        for (int i = 0; i < pmeCoordinateReceiverGpu->ppCommNumSenderRanks(); i++)
+        {
+            GpuEventSynchronizer event;
+            event.markEvent(*(pmeCoordinateReceiverGpu->ppCommStream(i)));
+            event.enqueueWaitEvent(pmeGpu->archSpecific->pmeStream_);
+        }
+    }
+    else // pipelining is not in use
+    {
+        if (useGpuDirectComm) // Sync all PME-PP communications to PME stream
+        {
+            pmeCoordinateReceiverGpu->synchronizeOnCoordinatesFromAllPpRanks(pmeGpu->archSpecific->pmeStream_);
+        }
+
+#if c_canEmbedBuffers
+        const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
+#else
+        const auto kernelArgs =
+                prepareGpuKernelArguments(kernelPtr,
+                                          config,
+                                          kernelParamsPtr,
+                                          &kernelParamsPtr->atoms.d_theta,
+                                          &kernelParamsPtr->atoms.d_dtheta,
+                                          &kernelParamsPtr->atoms.d_gridlineIndices,
+                                          &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A],
+                                          &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+                                          &kernelParamsPtr->grid.d_fractShiftsTable,
+                                          &kernelParamsPtr->grid.d_gridlineIndicesTable,
+                                          &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+                                          &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+                                          &kernelParamsPtr->atoms.d_coordinates);
+#endif
+
+        launchGpuKernel(kernelPtr,
+                        config,
+                        pmeGpu->archSpecific->pmeStream_,
+                        timingEvent,
+                        "PME spline/spread",
+                        kernelArgs);
+    }
+
     pme_gpu_stop_timing(pmeGpu, timingId);
 
-    const bool copyBackGrid =
-            spreadCharges && (pme_gpu_is_testing(pmeGpu) || !pme_gpu_performs_FFT(pmeGpu));
+    const auto& settings    = pmeGpu->settings;
+    const bool copyBackGrid = spreadCharges && (!settings.performGPUFFT || settings.copyAllOutputs);
     if (copyBackGrid)
     {
-        pme_gpu_copy_output_spread_grid(pmeGpu, h_grid);
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+        {
+            float* h_grid = h_grids[gridIndex];
+            pme_gpu_copy_output_spread_grid(pmeGpu, h_grid, gridIndex);
+        }
     }
     const bool copyBackAtomData =
-            computeSplines && (pme_gpu_is_testing(pmeGpu) || !pme_gpu_performs_gather(pmeGpu));
+            computeSplines && (!settings.performGPUGather || settings.copyAllOutputs);
     if (copyBackAtomData)
     {
         pme_gpu_copy_output_spread_atom_data(pmeGpu);
     }
 }
 
-void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrdering, bool computeEnergyAndVirial)
+void pme_gpu_solve(const PmeGpu* pmeGpu,
+                   const int     gridIndex,
+                   t_complex*    h_grid,
+                   GridOrdering  gridOrdering,
+                   bool          computeEnergyAndVirial)
 {
-    const bool copyInputAndOutputGrid = pme_gpu_is_testing(pmeGpu) || !pme_gpu_performs_FFT(pmeGpu);
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+    GMX_ASSERT(gridIndex < pmeGpu->common->ngrids,
+               "Invalid combination of gridIndex and number of grids");
+
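+    // The complex grid is staged through the host when the FFT is not done on the GPU
+    // or when all outputs are copied (tests).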
+    const auto& settings               = pmeGpu->settings;
+    const bool  copyInputAndOutputGrid = !settings.performGPUFFT || settings.copyAllOutputs;
 
     auto* kernelParamsPtr = pmeGpu->kernelParams.get();
 
     float* h_gridFloat = reinterpret_cast<float*>(h_grid);
     if (copyInputAndOutputGrid)
     {
-        copyToDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid, h_gridFloat, 0,
-                           pmeGpu->archSpecific->complexGridSize, pmeGpu->archSpecific->pmeStream,
-                           pmeGpu->settings.transferKind, nullptr);
+        copyToDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                           h_gridFloat,
+                           0,
+                           pmeGpu->archSpecific->complexGridSize[gridIndex],
+                           pmeGpu->archSpecific->pmeStream_,
+                           pmeGpu->settings.transferKind,
+                           nullptr);
     }
 
     int majorDim = -1, middleDim = -1, minorDim = -1;
@@ -1309,10 +1563,10 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
     {
         cellsPerBlock = (gridLineSize + blocksPerGridLine - 1) / blocksPerGridLine;
     }
-    const int warpSize  = pmeGpu->programHandle_->impl_->warpSize;
+    const int warpSize  = pmeGpu->programHandle_->warpSize();
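+    // Round cellsPerBlock up to a whole number of warps to get the block size.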
     const int blockSize = (cellsPerBlock + warpSize - 1) / warpSize * warpSize;
 
-    static_assert(GMX_GPU != GMX_GPU_CUDA || c_solveMaxWarpsPerBlock / 2 >= 4,
+    static_assert(!GMX_GPU_CUDA || c_solveMaxWarpsPerBlock / 2 >= 4,
                   "The CUDA solve energy kernels needs at least 4 warps. "
                   "Here we launch at least half of the max warps.");
 
@@ -1323,19 +1577,34 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
     config.gridSize[1] = (pmeGpu->kernelParams->grid.complexGridSize[middleDim] + gridLinesPerBlock - 1)
                          / gridLinesPerBlock;
     config.gridSize[2] = pmeGpu->kernelParams->grid.complexGridSize[majorDim];
-    config.stream      = pmeGpu->archSpecific->pmeStream;
 
-    int                                timingId  = gtPME_SOLVE;
+    PmeStage                           timingId  = PmeStage::Solve;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (gridOrdering == GridOrdering::YZX)
     {
-        kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernel
-                                           : pmeGpu->programHandle_->impl_->solveYZXKernel;
+        if (gridIndex == 0)
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernelA
+                                               : pmeGpu->programHandle_->impl_->solveYZXKernelA;
+        }
+        else
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernelB
+                                               : pmeGpu->programHandle_->impl_->solveYZXKernelB;
+        }
     }
     else if (gridOrdering == GridOrdering::XYZ)
     {
-        kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernel
-                                           : pmeGpu->programHandle_->impl_->solveXYZKernel;
+        if (gridIndex == 0)
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernelA
+                                               : pmeGpu->programHandle_->impl_->solveXYZKernelA;
+        }
+        else
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernelB
+                                               : pmeGpu->programHandle_->impl_->solveXYZKernelB;
+        }
     }
 
     pme_gpu_start_timing(pmeGpu, timingId);
@@ -1343,25 +1612,37 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
 #if c_canEmbedBuffers
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->grid.d_splineModuli,
-            &kernelParamsPtr->constants.d_virialAndEnergy, &kernelParamsPtr->grid.d_fourierGrid);
+    const auto kernelArgs =
+            prepareGpuKernelArguments(kernelPtr,
+                                      config,
+                                      kernelParamsPtr,
+                                      &kernelParamsPtr->grid.d_splineModuli[gridIndex],
+                                      &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex],
+                                      &kernelParamsPtr->grid.d_fourierGrid[gridIndex]);
 #endif
-    launchGpuKernel(kernelPtr, config, timingEvent, "PME solve", kernelArgs);
+    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME solve", kernelArgs);
     pme_gpu_stop_timing(pmeGpu, timingId);
 
     if (computeEnergyAndVirial)
     {
-        copyFromDeviceBuffer(pmeGpu->staging.h_virialAndEnergy,
-                             &kernelParamsPtr->constants.d_virialAndEnergy, 0, c_virialAndEnergyCount,
-                             pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
+        copyFromDeviceBuffer(pmeGpu->staging.h_virialAndEnergy[gridIndex],
+                             &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex],
+                             0,
+                             c_virialAndEnergyCount,
+                             pmeGpu->archSpecific->pmeStream_,
+                             pmeGpu->settings.transferKind,
+                             nullptr);
     }
 
     if (copyInputAndOutputGrid)
     {
-        copyFromDeviceBuffer(h_gridFloat, &kernelParamsPtr->grid.d_fourierGrid, 0,
-                             pmeGpu->archSpecific->complexGridSize, pmeGpu->archSpecific->pmeStream,
-                             pmeGpu->settings.transferKind, nullptr);
+        copyFromDeviceBuffer(h_gridFloat,
+                             &kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                             0,
+                             pmeGpu->archSpecific->complexGridSize[gridIndex],
+                             pmeGpu->archSpecific->pmeStream_,
+                             pmeGpu->settings.transferKind,
+                             nullptr);
     }
 }
 
@@ -1369,68 +1650,91 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
  * Returns a pointer to the appropriate gather kernel based on the input values
  *
  * \param[in]  pmeGpu                   The PME GPU structure.
- * \param[in]  useOrderThreadsPerAtom   bool controlling if we should use order or order*order threads per atom
+ * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  readSplinesFromGlobal    bool controlling if we should read spline data from global memory
- * \param[in]  forceTreatment           Controls if the forces from the gather should increment or replace the input forces.
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-inline auto selectGatherKernelPtr(const PmeGpu*          pmeGpu,
-                                  bool                   useOrderThreadsPerAtom,
-                                  bool                   readSplinesFromGlobal,
-                                  PmeForceOutputHandling forceTreatment)
+inline auto selectGatherKernelPtr(const PmeGpu*  pmeGpu,
+                                  ThreadsPerAtom threadsPerAtom,
+                                  bool           readSplinesFromGlobal,
+                                  const int      numGrids)
 
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
 
     if (readSplinesFromGlobal)
     {
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = (forceTreatment == PmeForceOutputHandling::Set)
-                                ? pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4
-                                : pmeGpu->programHandle_->impl_->gatherReduceWithInputKernelReadSplinesThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = (forceTreatment == PmeForceOutputHandling::Set)
-                                ? pmeGpu->programHandle_->impl_->gatherKernelReadSplines
-                                : pmeGpu->programHandle_->impl_->gatherReduceWithInputKernelReadSplines;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesSingle;
+            }
         }
     }
     else
     {
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = (forceTreatment == PmeForceOutputHandling::Set)
-                                ? pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4
-                                : pmeGpu->programHandle_->impl_->gatherReduceWithInputKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = (forceTreatment == PmeForceOutputHandling::Set)
-                                ? pmeGpu->programHandle_->impl_->gatherKernel
-                                : pmeGpu->programHandle_->impl_->gatherReduceWithInputKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelSingle;
+            }
         }
     }
     return kernelPtr;
 }
 
-
-void pme_gpu_gather(PmeGpu* pmeGpu, PmeForceOutputHandling forceTreatment, const float* h_grid)
+void pme_gpu_gather(PmeGpu* pmeGpu, real** h_grids, const float lambda)
 {
-    /* Copying the input CPU forces for reduction */
-    if (forceTreatment != PmeForceOutputHandling::Set)
-    {
-        pme_gpu_copy_input_forces(pmeGpu);
-    }
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
+    const auto& settings = pmeGpu->settings;
 
-    if (!pme_gpu_performs_FFT(pmeGpu) || pme_gpu_is_testing(pmeGpu))
+    if (!settings.performGPUFFT || settings.copyAllOutputs)
     {
-        pme_gpu_copy_input_gather_grid(pmeGpu, const_cast<float*>(h_grid));
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+        {
+            float* h_grid = const_cast<float*>(h_grids[gridIndex]);
+            pme_gpu_copy_input_gather_grid(pmeGpu, h_grid, gridIndex);
+        }
     }
 
-    if (pme_gpu_is_testing(pmeGpu))
+    if (settings.copyAllOutputs)
     {
         pme_gpu_copy_input_gather_atom_data(pmeGpu);
     }
@@ -1438,55 +1742,76 @@ void pme_gpu_gather(PmeGpu* pmeGpu, PmeForceOutputHandling forceTreatment, const
     /* copyAllOutputs is set when running unit tests */
     const bool   readGlobal = pmeGpu->settings.copyAllOutputs;
     const size_t blockSize  = pmeGpu->programHandle_->impl_->gatherWorkGroupSize;
-#if GMX_GPU == GMX_GPU_OPENCL
-    GMX_ASSERT(!c_useOrderThreadsPerAtom, "Only 16 threads per atom supported in OpenCL");
-    GMX_ASSERT(!c_recalculateSplines, "Recalculating splines not supported in OpenCL");
-#endif
-    const int atomsPerBlock = c_useOrderThreadsPerAtom ? blockSize / c_pmeSpreadGatherThreadsPerAtom4ThPerAtom
-                                                       : blockSize / c_pmeSpreadGatherThreadsPerAtom;
+    const int    order      = pmeGpu->common->pme_order;
+    GMX_ASSERT(order == c_pmeGpuOrder, "Only PME order 4 is implemented");
+    const int threadsPerAtom =
+            (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? order : order * order);
+    const bool recalculateSplines = pmeGpu->settings.recalculateSplines;
+
+    GMX_ASSERT(!GMX_GPU_OPENCL || pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::OrderSquared,
+               "Only 16 threads per atom supported in OpenCL");
+    GMX_ASSERT(!GMX_GPU_OPENCL || !recalculateSplines,
+               "Recalculating splines not supported in OpenCL");
+
+    const int atomsPerBlock = blockSize / threadsPerAtom;
 
-    GMX_ASSERT(!c_usePadding || !(c_pmeAtomDataAlignment % atomsPerBlock),
+    GMX_ASSERT(!(c_pmeAtomDataBlockSize % atomsPerBlock),
                "inconsistent atom data padding vs. gathering block size");
 
-    const int blockCount = pmeGpu->nAtomsPadded / atomsPerBlock;
+    const int blockCount = pmeGpu->nAtomsAlloc / atomsPerBlock;
     auto      dimGrid    = pmeGpuCreateGrid(pmeGpu, blockCount);
 
-    const int order = pmeGpu->common->pme_order;
-    GMX_ASSERT(order == c_pmeGpuOrder, "Only PME order 4 is implemented");
-
     KernelLaunchConfig config;
     config.blockSize[0] = order;
-    config.blockSize[1] = c_useOrderThreadsPerAtom ? 1 : order;
+    config.blockSize[1] = (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? 1 : order);
     config.blockSize[2] = atomsPerBlock;
     config.gridSize[0]  = dimGrid.first;
     config.gridSize[1]  = dimGrid.second;
-    config.stream       = pmeGpu->archSpecific->pmeStream;
 
     // TODO test different cache configs
 
-    int                                timingId  = gtPME_GATHER;
-    PmeGpuProgramImpl::PmeKernelHandle kernelPtr = selectGatherKernelPtr(
-            pmeGpu, c_useOrderThreadsPerAtom, readGlobal || (!c_recalculateSplines), forceTreatment);
+    PmeStage                           timingId = PmeStage::Gather;
+    PmeGpuProgramImpl::PmeKernelHandle kernelPtr =
+            selectGatherKernelPtr(pmeGpu,
+                                  pmeGpu->settings.threadsPerAtom,
+                                  readGlobal || (!recalculateSplines),
+                                  pmeGpu->common->ngrids);
     // TODO design kernel selection getters and make PmeGpu a friend of PmeGpuProgramImpl
 
     pme_gpu_start_timing(pmeGpu, timingId);
-    auto*       timingEvent     = pme_gpu_fetch_timing_event(pmeGpu, timingId);
-    const auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* timingEvent     = pme_gpu_fetch_timing_event(pmeGpu, timingId);
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+    if (pmeGpu->common->ngrids == 1)
+    {
+        kernelParamsPtr->current.scale = 1.0;
+    }
+    else
+    {
+        kernelParamsPtr->current.scale = 1.0 - lambda;
+    }
+
 #if c_canEmbedBuffers
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->atoms.d_coefficients,
-            &kernelParamsPtr->grid.d_realGrid, &kernelParamsPtr->atoms.d_theta,
-            &kernelParamsPtr->atoms.d_dtheta, &kernelParamsPtr->atoms.d_gridlineIndices,
-            &kernelParamsPtr->atoms.d_forces);
+    const auto kernelArgs =
+            prepareGpuKernelArguments(kernelPtr,
+                                      config,
+                                      kernelParamsPtr,
+                                      &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+                                      &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+                                      &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A],
+                                      &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+                                      &kernelParamsPtr->atoms.d_theta,
+                                      &kernelParamsPtr->atoms.d_dtheta,
+                                      &kernelParamsPtr->atoms.d_gridlineIndices,
+                                      &kernelParamsPtr->atoms.d_forces);
 #endif
-    launchGpuKernel(kernelPtr, config, timingEvent, "PME gather", kernelArgs);
+    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME gather", kernelArgs);
     pme_gpu_stop_timing(pmeGpu, timingId);
 
     if (pmeGpu->settings.useGpuForceReduction)
     {
-        pmeGpu->archSpecific->pmeForcesReady.markEvent(pmeGpu->archSpecific->pmeStream);
+        pmeGpu->archSpecific->pmeForcesReady.markEvent(pmeGpu->archSpecific->pmeStream_);
     }
     else
     {
@@ -1494,16 +1819,7 @@ void pme_gpu_gather(PmeGpu* pmeGpu, PmeForceOutputHandling forceTreatment, const
     }
 }
 
-DeviceBuffer<float> pme_gpu_get_kernelparam_coordinates(const PmeGpu* pmeGpu)
-{
-    GMX_ASSERT(pmeGpu && pmeGpu->kernelParams,
-               "PME GPU device buffer was requested in non-GPU build or before the GPU PME was "
-               "initialized.");
-
-    return pmeGpu->kernelParams->atoms.d_coordinates;
-}
-
-void* pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
+DeviceBuffer<gmx::RVec> pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
 {
     if (pmeGpu && pmeGpu->kernelParams)
     {
@@ -1511,42 +1827,11 @@ void* pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
     }
     else
     {
-        return nullptr;
+        return DeviceBuffer<gmx::RVec>{};
     }
 }
 
-/*! \brief Check the validity of the device buffer.
- *
- * Checks if the buffer is not nullptr and, when possible, if it is big enough.
- *
- * \todo Split and move this function to gpu_utils.
- *
- * \param[in] buffer        Device buffer to be checked.
- * \param[in] requiredSize  Number of elements that the buffer will have to accommodate.
- *
- * \returns If the device buffer can be set.
- */
-template<typename T>
-static bool checkDeviceBuffer(gmx_unused DeviceBuffer<T> buffer, gmx_unused int requiredSize)
-{
-#if GMX_GPU == GMX_GPU_CUDA
-    GMX_ASSERT(buffer != nullptr, "The device pointer is nullptr");
-    return buffer != nullptr;
-#elif GMX_GPU == GMX_GPU_OPENCL
-    size_t size;
-    int    retval = clGetMemObjectInfo(buffer, CL_MEM_SIZE, sizeof(size), &size, nullptr);
-    GMX_ASSERT(retval == CL_SUCCESS,
-               gmx::formatString("clGetMemObjectInfo failed with error code #%d", retval).c_str());
-    GMX_ASSERT(static_cast<int>(size) >= requiredSize,
-               "Number of atoms in device buffer is smaller then required size.");
-    return retval == CL_SUCCESS && static_cast<int>(size) >= requiredSize;
-#elif GMX_GPU == GMX_GPU_NONE
-    GMX_ASSERT(false, "Setter for device-side coordinates was called in non-GPU build.");
-    return false;
-#endif
-}
-
-void pme_gpu_set_kernelparam_coordinates(const PmeGpu* pmeGpu, DeviceBuffer<float> d_x)
+void pme_gpu_set_kernelparam_coordinates(const PmeGpu* pmeGpu, DeviceBuffer<gmx::RVec> d_x)
 {
     GMX_ASSERT(pmeGpu && pmeGpu->kernelParams,
                "PME GPU device buffer can not be set in non-GPU builds or before the GPU PME was "
@@ -1558,30 +1843,6 @@ void pme_gpu_set_kernelparam_coordinates(const PmeGpu* pmeGpu, DeviceBuffer<floa
     pmeGpu->kernelParams->atoms.d_coordinates = d_x;
 }
 
-void* pme_gpu_get_stream(const PmeGpu* pmeGpu)
-{
-    if (pmeGpu)
-    {
-        return static_cast<void*>(&pmeGpu->archSpecific->pmeStream);
-    }
-    else
-    {
-        return nullptr;
-    }
-}
-
-void* pme_gpu_get_context(const PmeGpu* pmeGpu)
-{
-    if (pmeGpu)
-    {
-        return static_cast<void*>(&pmeGpu->archSpecific->context);
-    }
-    else
-    {
-        return nullptr;
-    }
-}
-
 GpuEventSynchronizer* pme_gpu_get_forces_ready_synchronizer(const PmeGpu* pmeGpu)
 {
     if (pmeGpu && pmeGpu->kernelParams)