Unify coordinate copy handling across GPU platforms
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gpu_internal.cpp
index 9c984025256fcfc96c78e9d8357c63d7d0ddeef1..5e5bfd8ebeac7cdd9b6cf22ebff2f2dbf0cfc1ab 100644 (file)
@@ -1,7 +1,8 @@
 /*
  * This file is part of the GROMACS molecular simulation package.
  *
- * Copyright (c) 2016,2017,2018,2019,2020, by the GROMACS development team, led by
+ * Copyright (c) 2016,2017,2018,2019,2020 by the GROMACS development team.
+ * Copyright (c) 2021, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
 #include <string>
 
 #include "gromacs/ewald/ewald_utils.h"
+#include "gromacs/fft/gpu_3dfft.h"
+#include "gromacs/gpu_utils/device_context.h"
+#include "gromacs/gpu_utils/device_stream.h"
 #include "gromacs/gpu_utils/gpu_utils.h"
+#include "gromacs/gpu_utils/pmalloc.h"
+#if GMX_GPU_SYCL
+#    include "gromacs/gpu_utils/syclutils.h"
+#endif
+#include "gromacs/hardware/device_information.h"
 #include "gromacs/math/invertmatrix.h"
 #include "gromacs/math/units.h"
 #include "gromacs/timing/gpu_timing.h"
 #include "gromacs/utility/gmxassert.h"
 #include "gromacs/utility/logger.h"
 #include "gromacs/utility/stringutil.h"
+#include "gromacs/ewald/pme.h"
+#include "gromacs/ewald/pme_coordinate_receiver_gpu.h"
 
-#if GMX_GPU == GMX_GPU_CUDA
-#    include "gromacs/gpu_utils/pmalloc_cuda.h"
-
+#if GMX_GPU_CUDA
 #    include "pme.cuh"
-#elif GMX_GPU == GMX_GPU_OPENCL
-#    include "gromacs/gpu_utils/gmxopencl.h"
 #endif
 
-#include "gromacs/ewald/pme.h"
-
-#include "pme_gpu_3dfft.h"
+#include "pme_gpu_calculate_splines.h"
 #include "pme_gpu_constants.h"
 #include "pme_gpu_program_impl.h"
 #include "pme_gpu_timings.h"
 #include "pme_gpu_types.h"
 #include "pme_gpu_types_host.h"
 #include "pme_gpu_types_host_impl.h"
-#include "pme_gpu_utils.h"
 #include "pme_grid.h"
 #include "pme_internal.h"
 #include "pme_solve.h"
@@ -91,7 +95,7 @@
 /*! \brief
  * CUDA only
  * Atom limit above which it is advantageous to turn on the
- * recalcuating of the splines in the gather and using less threads per atom in the spline and spread
+ * recalculating of the splines in the gather and using less threads per atom in the spline and spread
  */
 constexpr int c_pmeGpuPerformanceAtomLimit = 23000;
 
@@ -108,60 +112,74 @@ static PmeGpuKernelParamsBase* pme_gpu_get_kernel_params_base_ptr(const PmeGpu*
     return kernelParamsPtr;
 }
 
-int pme_gpu_get_atom_data_alignment(const PmeGpu* /*unused*/)
-{
-    // TODO: this can be simplified, as c_pmeAtomDataAlignment is now constant
-    if (c_usePadding)
-    {
-        return c_pmeAtomDataAlignment;
-    }
-    else
-    {
-        return 0;
-    }
-}
+/*! \brief
+ * Atom data block size (in terms of number of atoms).
+ * This is the least common multiple of number of atoms processed by
+ * a single block/workgroup of the spread and gather kernels.
+ * The GPU atom data buffers must be padded, which means that
+ * the numbers of atoms used for determining the size of the memory
+ * allocation must be divisible by this.
+ */
+constexpr int c_pmeAtomDataBlockSize = 64;
 
-int pme_gpu_get_atoms_per_warp(const PmeGpu* pmeGpu)
+int pme_gpu_get_atom_data_block_size()
 {
-    if (pmeGpu->settings.useOrderThreadsPerAtom)
-    {
-        return pmeGpu->programHandle_->impl_->warpSize / c_pmeSpreadGatherThreadsPerAtom4ThPerAtom;
-    }
-    else
-    {
-        return pmeGpu->programHandle_->impl_->warpSize / c_pmeSpreadGatherThreadsPerAtom;
-    }
+    return c_pmeAtomDataBlockSize;
 }
 
 void pme_gpu_synchronize(const PmeGpu* pmeGpu)
 {
-    gpuStreamSynchronize(pmeGpu->archSpecific->pmeStream);
+    pmeGpu->archSpecific->pmeStream_.synchronize();
 }
 
 void pme_gpu_alloc_energy_virial(PmeGpu* pmeGpu)
 {
     const size_t energyAndVirialSize = c_virialAndEnergyCount * sizeof(float);
-    allocateDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy, c_virialAndEnergyCount,
-                         pmeGpu->archSpecific->context);
-    pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_virialAndEnergy), energyAndVirialSize);
+
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        allocateDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex],
+                             c_virialAndEnergyCount,
+                             pmeGpu->archSpecific->deviceContext_);
+        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_virialAndEnergy[gridIndex]), energyAndVirialSize);
+    }
 }
 
 void pme_gpu_free_energy_virial(PmeGpu* pmeGpu)
 {
-    freeDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy);
-    pfree(pmeGpu->staging.h_virialAndEnergy);
-    pmeGpu->staging.h_virialAndEnergy = nullptr;
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        freeDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex]);
+        pfree(pmeGpu->staging.h_virialAndEnergy[gridIndex]);
+        pmeGpu->staging.h_virialAndEnergy[gridIndex] = nullptr;
+    }
 }
 
 void pme_gpu_clear_energy_virial(const PmeGpu* pmeGpu)
 {
-    clearDeviceBufferAsync(&pmeGpu->kernelParams->constants.d_virialAndEnergy, 0,
-                           c_virialAndEnergyCount, pmeGpu->archSpecific->pmeStream);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex],
+                               0,
+                               c_virialAndEnergyCount,
+                               pmeGpu->archSpecific->pmeStream_);
+    }
 }
 
-void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu)
+void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu, const int gridIndex)
 {
-    const int splineValuesOffset[DIM] = { 0, pmeGpu->kernelParams->grid.realGridSize[XX],
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+    GMX_ASSERT(gridIndex < pmeGpu->common->ngrids,
+               "Invalid combination of gridIndex and number of grids");
+
+    const int splineValuesOffset[DIM] = { 0,
+                                          pmeGpu->kernelParams->grid.realGridSize[XX],
                                           pmeGpu->kernelParams->grid.realGridSize[XX]
                                                   + pmeGpu->kernelParams->grid.realGridSize[YY] };
     memcpy(&pmeGpu->kernelParams->grid.splineValuesOffset, &splineValuesOffset, sizeof(splineValuesOffset));
@@ -169,41 +187,53 @@ void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu)
     const int newSplineValuesSize = pmeGpu->kernelParams->grid.realGridSize[XX]
                                     + pmeGpu->kernelParams->grid.realGridSize[YY]
                                     + pmeGpu->kernelParams->grid.realGridSize[ZZ];
-    const bool shouldRealloc = (newSplineValuesSize > pmeGpu->archSpecific->splineValuesSize);
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli, newSplineValuesSize,
-                           &pmeGpu->archSpecific->splineValuesSize,
-                           &pmeGpu->archSpecific->splineValuesSizeAlloc, pmeGpu->archSpecific->context);
+    const bool shouldRealloc = (newSplineValuesSize > pmeGpu->archSpecific->splineValuesSize[gridIndex]);
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex],
+                           newSplineValuesSize,
+                           &pmeGpu->archSpecific->splineValuesSize[gridIndex],
+                           &pmeGpu->archSpecific->splineValuesCapacity[gridIndex],
+                           pmeGpu->archSpecific->deviceContext_);
     if (shouldRealloc)
     {
         /* Reallocate the host buffer */
-        pfree(pmeGpu->staging.h_splineModuli);
-        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_splineModuli),
+        pfree(pmeGpu->staging.h_splineModuli[gridIndex]);
+        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_splineModuli[gridIndex]),
                 newSplineValuesSize * sizeof(float));
     }
     for (int i = 0; i < DIM; i++)
     {
-        memcpy(pmeGpu->staging.h_splineModuli + splineValuesOffset[i],
-               pmeGpu->common->bsp_mod[i].data(), pmeGpu->common->bsp_mod[i].size() * sizeof(float));
+        memcpy(pmeGpu->staging.h_splineModuli[gridIndex] + splineValuesOffset[i],
+               pmeGpu->common->bsp_mod[i].data(),
+               pmeGpu->common->bsp_mod[i].size() * sizeof(float));
     }
     /* TODO: pin original buffer instead! */
-    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli, pmeGpu->staging.h_splineModuli,
-                       0, newSplineValuesSize, pmeGpu->archSpecific->pmeStream,
-                       pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex],
+                       pmeGpu->staging.h_splineModuli[gridIndex],
+                       0,
+                       newSplineValuesSize,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_free_bspline_values(const PmeGpu* pmeGpu)
 {
-    pfree(pmeGpu->staging.h_splineModuli);
-    freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        pfree(pmeGpu->staging.h_splineModuli[gridIndex]);
+        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex]);
+    }
 }
 
 void pme_gpu_realloc_forces(PmeGpu* pmeGpu)
 {
-    const size_t newForcesSize = pmeGpu->nAtomsAlloc * DIM;
+    const size_t newForcesSize = pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newForcesSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces, newForcesSize,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces,
+                           newForcesSize,
                            &pmeGpu->archSpecific->forcesSize,
-                           &pmeGpu->archSpecific->forcesSizeAlloc, pmeGpu->archSpecific->context);
+                           &pmeGpu->archSpecific->forcesSizeAlloc,
+                           pmeGpu->archSpecific->deviceContext_);
     pmeGpu->staging.h_forces.reserveWithPadding(pmeGpu->nAtomsAlloc);
     pmeGpu->staging.h_forces.resizeWithPadding(pmeGpu->kernelParams->atoms.nAtoms);
 }
@@ -216,65 +246,85 @@ void pme_gpu_free_forces(const PmeGpu* pmeGpu)
 void pme_gpu_copy_input_forces(PmeGpu* pmeGpu)
 {
     GMX_ASSERT(pmeGpu->kernelParams->atoms.nAtoms > 0, "Bad number of atoms in PME GPU");
-    float* h_forcesFloat = reinterpret_cast<float*>(pmeGpu->staging.h_forces.data());
-    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces, h_forcesFloat, 0,
-                       DIM * pmeGpu->kernelParams->atoms.nAtoms, pmeGpu->archSpecific->pmeStream,
-                       pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces,
+                       pmeGpu->staging.h_forces.data(),
+                       0,
+                       pmeGpu->kernelParams->atoms.nAtoms,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_copy_output_forces(PmeGpu* pmeGpu)
 {
     GMX_ASSERT(pmeGpu->kernelParams->atoms.nAtoms > 0, "Bad number of atoms in PME GPU");
-    float* h_forcesFloat = reinterpret_cast<float*>(pmeGpu->staging.h_forces.data());
-    copyFromDeviceBuffer(h_forcesFloat, &pmeGpu->kernelParams->atoms.d_forces, 0,
-                         DIM * pmeGpu->kernelParams->atoms.nAtoms, pmeGpu->archSpecific->pmeStream,
-                         pmeGpu->settings.transferKind, nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_forces.data(),
+                         &pmeGpu->kernelParams->atoms.d_forces,
+                         0,
+                         pmeGpu->kernelParams->atoms.nAtoms,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
 }
 
-void pme_gpu_realloc_and_copy_input_coefficients(PmeGpu* pmeGpu, const float* h_coefficients)
+void pme_gpu_realloc_and_copy_input_coefficients(const PmeGpu* pmeGpu,
+                                                 const float*  h_coefficients,
+                                                 const int     gridIndex)
 {
     GMX_ASSERT(h_coefficients, "Bad host-side charge buffer in PME GPU");
     const size_t newCoefficientsSize = pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newCoefficientsSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients, newCoefficientsSize,
-                           &pmeGpu->archSpecific->coefficientsSize,
-                           &pmeGpu->archSpecific->coefficientsSizeAlloc, pmeGpu->archSpecific->context);
-    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients,
-                       const_cast<float*>(h_coefficients), 0, pmeGpu->kernelParams->atoms.nAtoms,
-                       pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
-    if (c_usePadding)
-    {
-        const size_t paddingIndex = pmeGpu->kernelParams->atoms.nAtoms;
-        const size_t paddingCount = pmeGpu->nAtomsAlloc - paddingIndex;
-        if (paddingCount > 0)
-        {
-            clearDeviceBufferAsync(&pmeGpu->kernelParams->atoms.d_coefficients, paddingIndex,
-                                   paddingCount, pmeGpu->archSpecific->pmeStream);
-        }
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
+                           newCoefficientsSize,
+                           &pmeGpu->archSpecific->coefficientsSize[gridIndex],
+                           &pmeGpu->archSpecific->coefficientsCapacity[gridIndex],
+                           pmeGpu->archSpecific->deviceContext_);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
+                       const_cast<float*>(h_coefficients),
+                       0,
+                       pmeGpu->kernelParams->atoms.nAtoms,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
+
+    const size_t paddingIndex = pmeGpu->kernelParams->atoms.nAtoms;
+    const size_t paddingCount = pmeGpu->nAtomsAlloc - paddingIndex;
+    if (paddingCount > 0)
+    {
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
+                               paddingIndex,
+                               paddingCount,
+                               pmeGpu->archSpecific->pmeStream_);
     }
 }
 
 void pme_gpu_free_coefficients(const PmeGpu* pmeGpu)
 {
-    freeDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        freeDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex]);
+    }
 }
 
 void pme_gpu_realloc_spline_data(PmeGpu* pmeGpu)
 {
-    const int    order        = pmeGpu->common->pme_order;
-    const int    alignment    = pme_gpu_get_atoms_per_warp(pmeGpu);
-    const size_t nAtomsPadded = ((pmeGpu->nAtomsAlloc + alignment - 1) / alignment) * alignment;
-    const int    newSplineDataSize = DIM * order * nAtomsPadded;
+    const int order             = pmeGpu->common->pme_order;
+    const int newSplineDataSize = DIM * order * pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newSplineDataSize > 0, "Bad number of atoms in PME GPU");
     /* Two arrays of the same size */
     const bool shouldRealloc        = (newSplineDataSize > pmeGpu->archSpecific->splineDataSize);
     int        currentSizeTemp      = pmeGpu->archSpecific->splineDataSize;
     int        currentSizeTempAlloc = pmeGpu->archSpecific->splineDataSizeAlloc;
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_theta, newSplineDataSize,
-                           &currentSizeTemp, &currentSizeTempAlloc, pmeGpu->archSpecific->context);
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_dtheta, newSplineDataSize,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_theta,
+                           newSplineDataSize,
+                           &currentSizeTemp,
+                           &currentSizeTempAlloc,
+                           pmeGpu->archSpecific->deviceContext_);
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_dtheta,
+                           newSplineDataSize,
                            &pmeGpu->archSpecific->splineDataSize,
-                           &pmeGpu->archSpecific->splineDataSizeAlloc, pmeGpu->archSpecific->context);
+                           &pmeGpu->archSpecific->splineDataSizeAlloc,
+                           pmeGpu->archSpecific->deviceContext_);
     // the host side reallocation
     if (shouldRealloc)
     {
@@ -298,9 +348,11 @@ void pme_gpu_realloc_grid_indices(PmeGpu* pmeGpu)
 {
     const size_t newIndicesSize = DIM * pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newIndicesSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_gridlineIndices, newIndicesSize,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_gridlineIndices,
+                           newIndicesSize,
                            &pmeGpu->archSpecific->gridlineIndicesSize,
-                           &pmeGpu->archSpecific->gridlineIndicesSizeAlloc, pmeGpu->archSpecific->context);
+                           &pmeGpu->archSpecific->gridlineIndicesSizeAlloc,
+                           pmeGpu->archSpecific->deviceContext_);
     pfree(pmeGpu->staging.h_gridlineIndices);
     pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_gridlineIndices), newIndicesSize * sizeof(int));
 }
@@ -313,50 +365,69 @@ void pme_gpu_free_grid_indices(const PmeGpu* pmeGpu)
 
 void pme_gpu_realloc_grids(PmeGpu* pmeGpu)
 {
-    auto*     kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+
     const int newRealGridSize = kernelParamsPtr->grid.realGridSizePadded[XX]
                                 * kernelParamsPtr->grid.realGridSizePadded[YY]
                                 * kernelParamsPtr->grid.realGridSizePadded[ZZ];
     const int newComplexGridSize = kernelParamsPtr->grid.complexGridSizePadded[XX]
                                    * kernelParamsPtr->grid.complexGridSizePadded[YY]
                                    * kernelParamsPtr->grid.complexGridSizePadded[ZZ] * 2;
-    // Multiplied by 2 because we count complex grid size for complex numbers, but all allocations/pointers are float
-    if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        /* 2 separate grids */
-        reallocateDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid, newComplexGridSize,
-                               &pmeGpu->archSpecific->complexGridSize,
-                               &pmeGpu->archSpecific->complexGridSizeAlloc, pmeGpu->archSpecific->context);
-        reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid, newRealGridSize,
-                               &pmeGpu->archSpecific->realGridSize,
-                               &pmeGpu->archSpecific->realGridSizeAlloc, pmeGpu->archSpecific->context);
-    }
-    else
-    {
-        /* A single buffer so that any grid will fit */
-        const int newGridsSize = std::max(newRealGridSize, newComplexGridSize);
-        reallocateDeviceBuffer(
-                &kernelParamsPtr->grid.d_realGrid, newGridsSize, &pmeGpu->archSpecific->realGridSize,
-                &pmeGpu->archSpecific->realGridSizeAlloc, pmeGpu->archSpecific->context);
-        kernelParamsPtr->grid.d_fourierGrid   = kernelParamsPtr->grid.d_realGrid;
-        pmeGpu->archSpecific->complexGridSize = pmeGpu->archSpecific->realGridSize;
-        // the size might get used later for copying the grid
+        // Multiplied by 2 because we count complex grid size for complex numbers, but all allocations/pointers are float
+        if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+        {
+            /* 2 separate grids */
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                                   newComplexGridSize,
+                                   &pmeGpu->archSpecific->complexGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->complexGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex],
+                                   newRealGridSize,
+                                   &pmeGpu->archSpecific->realGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->realGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+        }
+        else
+        {
+            /* A single buffer so that any grid will fit */
+            const int newGridsSize = std::max(newRealGridSize, newComplexGridSize);
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex],
+                                   newGridsSize,
+                                   &pmeGpu->archSpecific->realGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->realGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+            kernelParamsPtr->grid.d_fourierGrid[gridIndex] = kernelParamsPtr->grid.d_realGrid[gridIndex];
+            pmeGpu->archSpecific->complexGridSize[gridIndex] =
+                    pmeGpu->archSpecific->realGridSize[gridIndex];
+            // the size might get used later for copying the grid
+        }
     }
 }
 
 void pme_gpu_free_grids(const PmeGpu* pmeGpu)
 {
-    if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_fourierGrid);
+        if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+        {
+            freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_fourierGrid[gridIndex]);
+        }
+        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex]);
     }
-    freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid);
 }
 
 void pme_gpu_clear_grids(const PmeGpu* pmeGpu)
 {
-    clearDeviceBufferAsync(&pmeGpu->kernelParams->grid.d_realGrid, 0,
-                           pmeGpu->archSpecific->realGridSize, pmeGpu->archSpecific->pmeStream);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                               0,
+                               pmeGpu->archSpecific->realGridSize[gridIndex],
+                               pmeGpu->archSpecific->pmeStream_);
+    }
 }
 
 void pme_gpu_realloc_and_copy_fract_shifts(PmeGpu* pmeGpu)
@@ -375,37 +446,28 @@ void pme_gpu_realloc_and_copy_fract_shifts(PmeGpu* pmeGpu)
 
     const int newFractShiftsSize = cellCount * (nx + ny + nz);
 
-#if GMX_GPU == GMX_GPU_CUDA
-    initParamLookupTable(kernelParamsPtr->grid.d_fractShiftsTable, kernelParamsPtr->fractShiftsTableTexture,
-                         pmeGpu->common->fsh.data(), newFractShiftsSize);
-
-    initParamLookupTable(kernelParamsPtr->grid.d_gridlineIndicesTable,
-                         kernelParamsPtr->gridlineIndicesTableTexture, pmeGpu->common->nn.data(),
-                         newFractShiftsSize);
-#elif GMX_GPU == GMX_GPU_OPENCL
-    // No dedicated texture routines....
-    allocateDeviceBuffer(&kernelParamsPtr->grid.d_fractShiftsTable, newFractShiftsSize,
-                         pmeGpu->archSpecific->context);
-    allocateDeviceBuffer(&kernelParamsPtr->grid.d_gridlineIndicesTable, newFractShiftsSize,
-                         pmeGpu->archSpecific->context);
-    copyToDeviceBuffer(&kernelParamsPtr->grid.d_fractShiftsTable, pmeGpu->common->fsh.data(), 0,
-                       newFractShiftsSize, pmeGpu->archSpecific->pmeStream,
-                       GpuApiCallBehavior::Async, nullptr);
-    copyToDeviceBuffer(&kernelParamsPtr->grid.d_gridlineIndicesTable, pmeGpu->common->nn.data(), 0,
-                       newFractShiftsSize, pmeGpu->archSpecific->pmeStream,
-                       GpuApiCallBehavior::Async, nullptr);
-#endif
+    initParamLookupTable(&kernelParamsPtr->grid.d_fractShiftsTable,
+                         &kernelParamsPtr->fractShiftsTableTexture,
+                         pmeGpu->common->fsh.data(),
+                         newFractShiftsSize,
+                         pmeGpu->archSpecific->deviceContext_);
+
+    initParamLookupTable(&kernelParamsPtr->grid.d_gridlineIndicesTable,
+                         &kernelParamsPtr->gridlineIndicesTableTexture,
+                         pmeGpu->common->nn.data(),
+                         newFractShiftsSize,
+                         pmeGpu->archSpecific->deviceContext_);
 }
 
 void pme_gpu_free_fract_shifts(const PmeGpu* pmeGpu)
 {
     auto* kernelParamsPtr = pmeGpu->kernelParams.get();
-#if GMX_GPU == GMX_GPU_CUDA
-    destroyParamLookupTable(kernelParamsPtr->grid.d_fractShiftsTable,
-                            kernelParamsPtr->fractShiftsTableTexture);
-    destroyParamLookupTable(kernelParamsPtr->grid.d_gridlineIndicesTable,
-                            kernelParamsPtr->gridlineIndicesTableTexture);
-#elif GMX_GPU == GMX_GPU_OPENCL
+#if GMX_GPU_CUDA
+    destroyParamLookupTable(&kernelParamsPtr->grid.d_fractShiftsTable,
+                            &kernelParamsPtr->fractShiftsTableTexture);
+    destroyParamLookupTable(&kernelParamsPtr->grid.d_gridlineIndicesTable,
+                            &kernelParamsPtr->gridlineIndicesTableTexture);
+#elif GMX_GPU_OPENCL || GMX_GPU_SYCL
     freeDeviceBuffer(&kernelParamsPtr->grid.d_fractShiftsTable);
     freeDeviceBuffer(&kernelParamsPtr->grid.d_gridlineIndicesTable);
 #endif
@@ -413,63 +475,99 @@ void pme_gpu_free_fract_shifts(const PmeGpu* pmeGpu)
 
 bool pme_gpu_stream_query(const PmeGpu* pmeGpu)
 {
-    return haveStreamTasksCompleted(pmeGpu->archSpecific->pmeStream);
+    return haveStreamTasksCompleted(pmeGpu->archSpecific->pmeStream_);
 }
 
-void pme_gpu_copy_input_gather_grid(const PmeGpu* pmeGpu, float* h_grid)
+void pme_gpu_copy_input_gather_grid(const PmeGpu* pmeGpu, const float* h_grid, const int gridIndex)
 {
-    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid, h_grid, 0, pmeGpu->archSpecific->realGridSize,
-                       pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                       h_grid,
+                       0,
+                       pmeGpu->archSpecific->realGridSize[gridIndex],
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
-void pme_gpu_copy_output_spread_grid(const PmeGpu* pmeGpu, float* h_grid)
+void pme_gpu_copy_output_spread_grid(const PmeGpu* pmeGpu, float* h_grid, const int gridIndex)
 {
-    copyFromDeviceBuffer(h_grid, &pmeGpu->kernelParams->grid.d_realGrid, 0,
-                         pmeGpu->archSpecific->realGridSize, pmeGpu->archSpecific->pmeStream,
-                         pmeGpu->settings.transferKind, nullptr);
-    pmeGpu->archSpecific->syncSpreadGridD2H.markEvent(pmeGpu->archSpecific->pmeStream);
+    copyFromDeviceBuffer(h_grid,
+                         &pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                         0,
+                         pmeGpu->archSpecific->realGridSize[gridIndex],
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
+    pmeGpu->archSpecific->syncSpreadGridD2H.markEvent(pmeGpu->archSpecific->pmeStream_);
 }
 
 void pme_gpu_copy_output_spread_atom_data(const PmeGpu* pmeGpu)
 {
-    const int    alignment       = pme_gpu_get_atoms_per_warp(pmeGpu);
-    const size_t nAtomsPadded    = ((pmeGpu->nAtomsAlloc + alignment - 1) / alignment) * alignment;
-    const size_t splinesCount    = DIM * nAtomsPadded * pmeGpu->common->pme_order;
+    const size_t splinesCount    = DIM * pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order;
     auto*        kernelParamsPtr = pmeGpu->kernelParams.get();
-    copyFromDeviceBuffer(pmeGpu->staging.h_dtheta, &kernelParamsPtr->atoms.d_dtheta, 0, splinesCount,
-                         pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
-    copyFromDeviceBuffer(pmeGpu->staging.h_theta, &kernelParamsPtr->atoms.d_theta, 0, splinesCount,
-                         pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
-    copyFromDeviceBuffer(pmeGpu->staging.h_gridlineIndices, &kernelParamsPtr->atoms.d_gridlineIndices,
-                         0, kernelParamsPtr->atoms.nAtoms * DIM, pmeGpu->archSpecific->pmeStream,
-                         pmeGpu->settings.transferKind, nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_dtheta,
+                         &kernelParamsPtr->atoms.d_dtheta,
+                         0,
+                         splinesCount,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_theta,
+                         &kernelParamsPtr->atoms.d_theta,
+                         0,
+                         splinesCount,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_gridlineIndices,
+                         &kernelParamsPtr->atoms.d_gridlineIndices,
+                         0,
+                         kernelParamsPtr->atoms.nAtoms * DIM,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
 }
 
 void pme_gpu_copy_input_gather_atom_data(const PmeGpu* pmeGpu)
 {
-    const int    alignment       = pme_gpu_get_atoms_per_warp(pmeGpu);
-    const size_t nAtomsPadded    = ((pmeGpu->nAtomsAlloc + alignment - 1) / alignment) * alignment;
-    const size_t splinesCount    = DIM * nAtomsPadded * pmeGpu->common->pme_order;
+    const size_t splinesCount    = DIM * pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order;
     auto*        kernelParamsPtr = pmeGpu->kernelParams.get();
-    if (c_usePadding)
-    {
-        // TODO: could clear only the padding and not the whole thing, but this is a test-exclusive code anyway
-        clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_gridlineIndices, 0,
-                               pmeGpu->nAtomsAlloc * DIM, pmeGpu->archSpecific->pmeStream);
-        clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_dtheta, 0,
-                               pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
-                               pmeGpu->archSpecific->pmeStream);
-        clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_theta, 0,
-                               pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
-                               pmeGpu->archSpecific->pmeStream);
-    }
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_dtheta, pmeGpu->staging.h_dtheta, 0, splinesCount,
-                       pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_theta, pmeGpu->staging.h_theta, 0, splinesCount,
-                       pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_gridlineIndices, pmeGpu->staging.h_gridlineIndices,
-                       0, kernelParamsPtr->atoms.nAtoms * DIM, pmeGpu->archSpecific->pmeStream,
-                       pmeGpu->settings.transferKind, nullptr);
+
+    // TODO: could clear only the padding and not the whole thing, but this is a test-exclusive code anyway
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_gridlineIndices,
+                           0,
+                           pmeGpu->nAtomsAlloc * DIM,
+                           pmeGpu->archSpecific->pmeStream_);
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_dtheta,
+                           0,
+                           pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
+                           pmeGpu->archSpecific->pmeStream_);
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_theta,
+                           0,
+                           pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
+                           pmeGpu->archSpecific->pmeStream_);
+
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_dtheta,
+                       pmeGpu->staging.h_dtheta,
+                       0,
+                       splinesCount,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_theta,
+                       pmeGpu->staging.h_theta,
+                       0,
+                       splinesCount,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_gridlineIndices,
+                       pmeGpu->staging.h_gridlineIndices,
+                       0,
+                       kernelParamsPtr->atoms.nAtoms * DIM,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_sync_spread_grid(const PmeGpu* pmeGpu)
@@ -477,16 +575,16 @@ void pme_gpu_sync_spread_grid(const PmeGpu* pmeGpu)
     pmeGpu->archSpecific->syncSpreadGridD2H.waitForEvent();
 }
 
-void pme_gpu_init_internal(PmeGpu* pmeGpu)
+/*! \brief Internal GPU initialization for PME.
+ *
+ * \param[in]  pmeGpu         GPU PME data.
+ * \param[in]  deviceContext  GPU context.
+ * \param[in]  deviceStream   GPU stream.
+ */
+static void pme_gpu_init_internal(PmeGpu* pmeGpu, const DeviceContext& deviceContext, const DeviceStream& deviceStream)
 {
-#if GMX_GPU == GMX_GPU_CUDA
-    // Prepare to use the device that this PME task was assigned earlier.
-    // Other entities, such as CUDA timing events, are known to implicitly use the device context.
-    CU_RET_ERR(cudaSetDevice(pmeGpu->deviceInfo->id), "Switching to PME CUDA device");
-#endif
-
     /* Allocate the target-specific structures */
-    pmeGpu->archSpecific.reset(new PmeGpuSpecific());
+    pmeGpu->archSpecific.reset(new PmeGpuSpecific(deviceContext, deviceStream));
     pmeGpu->kernelParams.reset(new PmeGpuKernelParams());
 
     pmeGpu->archSpecific->performOutOfPlaceFFT = true;
@@ -495,69 +593,15 @@ void pme_gpu_init_internal(PmeGpu* pmeGpu)
      * TODO: PME could also try to pick up nice grid sizes (with factors of 2, 3, 5, 7).
      */
 
-    // TODO: this is just a convenient reuse because programHandle_ currently is in charge of creating context
-    pmeGpu->archSpecific->context = pmeGpu->programHandle_->impl_->context;
-
-    // timing enabling - TODO put this in gpu_utils (even though generally this is just option handling?) and reuse in NB
-    if (GMX_GPU == GMX_GPU_CUDA)
-    {
-        /* WARNING: CUDA timings are incorrect with multiple streams.
-         *          This is the main reason why they are disabled by default.
-         */
-        // TODO: Consider turning on by default when we can detect nr of streams.
-        pmeGpu->archSpecific->useTiming = (getenv("GMX_ENABLE_GPU_TIMING") != nullptr);
-    }
-    else if (GMX_GPU == GMX_GPU_OPENCL)
-    {
-        pmeGpu->archSpecific->useTiming = (getenv("GMX_DISABLE_GPU_TIMING") == nullptr);
-    }
-
-#if GMX_GPU == GMX_GPU_CUDA
-    pmeGpu->maxGridWidthX = pmeGpu->deviceInfo->prop.maxGridSize[0];
-#elif GMX_GPU == GMX_GPU_OPENCL
-    pmeGpu->maxGridWidthX = INT32_MAX / 2;
+#if GMX_GPU_CUDA
+    pmeGpu->kernelParams->usePipeline       = char(false);
+    pmeGpu->kernelParams->pipelineAtomStart = 0;
+    pmeGpu->kernelParams->pipelineAtomEnd   = 0;
+    pmeGpu->maxGridWidthX                   = deviceContext.deviceInfo().prop.maxGridSize[0];
+#else
+    // Use this path for any non-CUDA GPU acceleration
     // TODO: is there no really global work size limit in OpenCL?
-#endif
-
-    /* Creating a PME GPU stream:
-     * - default high priority with CUDA
-     * - no priorities implemented yet with OpenCL; see #2532
-     */
-#if GMX_GPU == GMX_GPU_CUDA
-    cudaError_t stat;
-    int         highest_priority, lowest_priority;
-    stat = cudaDeviceGetStreamPriorityRange(&lowest_priority, &highest_priority);
-    CU_RET_ERR(stat, "PME cudaDeviceGetStreamPriorityRange failed");
-    stat = cudaStreamCreateWithPriority(&pmeGpu->archSpecific->pmeStream,
-                                        cudaStreamDefault, // cudaStreamNonBlocking,
-                                        highest_priority);
-    CU_RET_ERR(stat, "cudaStreamCreateWithPriority on the PME stream failed");
-#elif GMX_GPU == GMX_GPU_OPENCL
-    cl_command_queue_properties queueProperties =
-            pmeGpu->archSpecific->useTiming ? CL_QUEUE_PROFILING_ENABLE : 0;
-    cl_device_id device_id = pmeGpu->deviceInfo->oclDeviceId;
-    cl_int       clError;
-    pmeGpu->archSpecific->pmeStream =
-            clCreateCommandQueue(pmeGpu->archSpecific->context, device_id, queueProperties, &clError);
-    if (clError != CL_SUCCESS)
-    {
-        GMX_THROW(gmx::InternalError("Failed to create PME command queue"));
-    }
-#endif
-}
-
-void pme_gpu_destroy_specific(const PmeGpu* pmeGpu)
-{
-#if GMX_GPU == GMX_GPU_CUDA
-    /* Destroy the CUDA stream */
-    cudaError_t stat = cudaStreamDestroy(pmeGpu->archSpecific->pmeStream);
-    CU_RET_ERR(stat, "PME cudaStreamDestroy error");
-#elif GMX_GPU == GMX_GPU_OPENCL
-    cl_int clError = clReleaseCommandQueue(pmeGpu->archSpecific->pmeStream);
-    if (clError != CL_SUCCESS)
-    {
-        gmx_warning("Failed to destroy PME command queue");
-    }
+    pmeGpu->maxGridWidthX = INT32_MAX / 2;
 #endif
 }
 
@@ -566,9 +610,46 @@ void pme_gpu_reinit_3dfft(const PmeGpu* pmeGpu)
     if (pme_gpu_settings(pmeGpu).performGPUFFT)
     {
         pmeGpu->archSpecific->fftSetup.resize(0);
-        for (int i = 0; i < pmeGpu->common->ngrids; i++)
+        const bool         performOutOfPlaceFFT      = pmeGpu->archSpecific->performOutOfPlaceFFT;
+        const bool         allocateGrid              = false;
+        MPI_Comm           comm                      = MPI_COMM_NULL;
+        std::array<int, 1> gridOffsetsInXForEachRank = { 0 };
+        std::array<int, 1> gridOffsetsInYForEachRank = { 0 };
+#if GMX_GPU_CUDA
+        const gmx::FftBackend backend = gmx::FftBackend::Cufft;
+#elif GMX_GPU_OPENCL
+        const gmx::FftBackend backend = gmx::FftBackend::Ocl;
+#elif GMX_GPU_SYCL
+#    if GMX_SYCL_DPCPP && GMX_FFT_MKL
+        const gmx::FftBackend backend = gmx::FftBackend::SyclMkl;
+#    elif GMX_SYCL_HIPSYCL
+        const gmx::FftBackend backend = gmx::FftBackend::SyclRocfft;
+#    else
+        const gmx::FftBackend backend = gmx::FftBackend::Sycl;
+#    endif
+#else
+        GMX_RELEASE_ASSERT(false, "Unknown GPU backend");
+        const gmx::FftBackend backend = gmx::FftBackend::Count;
+#endif
+
+        PmeGpuGridParams& grid = pme_gpu_get_kernel_params_base_ptr(pmeGpu)->grid;
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
         {
-            pmeGpu->archSpecific->fftSetup.push_back(std::make_unique<GpuParallel3dFft>(pmeGpu));
+            pmeGpu->archSpecific->fftSetup.push_back(
+                    std::make_unique<gmx::Gpu3dFft>(backend,
+                                                    allocateGrid,
+                                                    comm,
+                                                    gridOffsetsInXForEachRank,
+                                                    gridOffsetsInYForEachRank,
+                                                    grid.realGridSize[ZZ],
+                                                    performOutOfPlaceFFT,
+                                                    pmeGpu->archSpecific->deviceContext_,
+                                                    pmeGpu->archSpecific->pmeStream_,
+                                                    grid.realGridSize,
+                                                    grid.realGridSizePadded,
+                                                    grid.complexGridSizePadded,
+                                                    &(grid.d_realGrid[gridIndex]),
+                                                    &(grid.d_fourierGrid[gridIndex])));
         }
     }
 }
@@ -578,69 +659,62 @@ void pme_gpu_destroy_3dfft(const PmeGpu* pmeGpu)
     pmeGpu->archSpecific->fftSetup.resize(0);
 }
 
-int getSplineParamFullIndex(int order, int splineIndex, int dimIndex, int atomIndex, int atomsPerWarp)
+void pme_gpu_getEnergyAndVirial(const gmx_pme_t& pme, const float lambda, PmeOutput* output)
 {
-    if (order != c_pmeGpuOrder)
+    const PmeGpu* pmeGpu = pme.gpu;
+
+    GMX_ASSERT(lambda == 1.0 || pmeGpu->common->ngrids == 2,
+               "Invalid combination of lambda and number of grids");
+
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        throw order;
+        for (int j = 0; j < c_virialAndEnergyCount; j++)
+        {
+            GMX_ASSERT(std::isfinite(pmeGpu->staging.h_virialAndEnergy[gridIndex][j]),
+                       "PME GPU produces incorrect energy/virial.");
+        }
     }
-    constexpr int fixedOrder = c_pmeGpuOrder;
-    GMX_UNUSED_VALUE(fixedOrder);
-
-    const int atomWarpIndex = atomIndex % atomsPerWarp;
-    const int warpIndex     = atomIndex / atomsPerWarp;
-    int       indexBase, result;
-    switch (atomsPerWarp)
+    for (int dim1 = 0; dim1 < DIM; dim1++)
     {
-        case 1:
-            indexBase = getSplineParamIndexBase<fixedOrder, 1>(warpIndex, atomWarpIndex);
-            result    = getSplineParamIndex<fixedOrder, 1>(indexBase, dimIndex, splineIndex);
-            break;
-
-        case 2:
-            indexBase = getSplineParamIndexBase<fixedOrder, 2>(warpIndex, atomWarpIndex);
-            result    = getSplineParamIndex<fixedOrder, 2>(indexBase, dimIndex, splineIndex);
-            break;
-
-        case 4:
-            indexBase = getSplineParamIndexBase<fixedOrder, 4>(warpIndex, atomWarpIndex);
-            result    = getSplineParamIndex<fixedOrder, 4>(indexBase, dimIndex, splineIndex);
-            break;
-
-        case 8:
-            indexBase = getSplineParamIndexBase<fixedOrder, 8>(warpIndex, atomWarpIndex);
-            result    = getSplineParamIndex<fixedOrder, 8>(indexBase, dimIndex, splineIndex);
-            break;
-
-        default:
-            GMX_THROW(gmx::NotImplementedError(
-                    gmx::formatString("Test function call not unrolled for atomsPerWarp = %d in "
-                                      "getSplineParamFullIndex",
-                                      atomsPerWarp)));
+        for (int dim2 = 0; dim2 < DIM; dim2++)
+        {
+            output->coulombVirial_[dim1][dim2] = 0;
+        }
+    }
+    output->coulombEnergy_ = 0;
+    float scale            = 1.0;
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        if (pmeGpu->common->ngrids == 2)
+        {
+            scale = gridIndex == 0 ? (1.0 - lambda) : lambda;
+        }
+        output->coulombVirial_[XX][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][0];
+        output->coulombVirial_[YY][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][1];
+        output->coulombVirial_[ZZ][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][2];
+        output->coulombVirial_[XX][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][3];
+        output->coulombVirial_[YY][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][3];
+        output->coulombVirial_[XX][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][4];
+        output->coulombVirial_[ZZ][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][4];
+        output->coulombVirial_[YY][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][5];
+        output->coulombVirial_[ZZ][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][5];
+        output->coulombEnergy_ += scale * 0.5F * pmeGpu->staging.h_virialAndEnergy[gridIndex][6];
+    }
+    if (pmeGpu->common->ngrids > 1)
+    {
+        output->coulombDvdl_ = 0.5F
+                               * (pmeGpu->staging.h_virialAndEnergy[FEP_STATE_B][6]
+                                  - pmeGpu->staging.h_virialAndEnergy[FEP_STATE_A][6]);
     }
-    return result;
-}
-
-void pme_gpu_getEnergyAndVirial(const gmx_pme_t& pme, PmeOutput* output)
-{
-    const PmeGpu* pmeGpu = pme.gpu;
-    for (int j = 0; j < c_virialAndEnergyCount; j++)
-    {
-        GMX_ASSERT(std::isfinite(pmeGpu->staging.h_virialAndEnergy[j]),
-                   "PME GPU produces incorrect energy/virial.");
-    }
-
-    unsigned int j                 = 0;
-    output->coulombVirial_[XX][XX] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[YY][YY] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[ZZ][ZZ] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[XX][YY] = output->coulombVirial_[YY][XX] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[XX][ZZ] = output->coulombVirial_[ZZ][XX] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[YY][ZZ] = output->coulombVirial_[ZZ][YY] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombEnergy_ = 0.5F * pmeGpu->staging.h_virialAndEnergy[j++];
 }
 
 /*! \brief Sets the force-related members in \p output
@@ -657,22 +731,19 @@ static void pme_gpu_getForceOutput(PmeGpu* pmeGpu, PmeOutput* output)
     }
 }
 
-PmeOutput pme_gpu_getOutput(const gmx_pme_t& pme, const int flags)
+PmeOutput pme_gpu_getOutput(const gmx_pme_t& pme, const bool computeEnergyAndVirial, const real lambdaQ)
 {
-    PmeGpu*    pmeGpu                      = pme.gpu;
-    const bool haveComputedEnergyAndVirial = (flags & GMX_PME_CALC_ENER_VIR) != 0;
+    PmeGpu* pmeGpu = pme.gpu;
 
     PmeOutput output;
 
     pme_gpu_getForceOutput(pmeGpu, &output);
 
-    // The caller knows from the flags that the energy and the virial are not usable
-    // on the else branch
-    if (haveComputedEnergyAndVirial)
+    if (computeEnergyAndVirial)
     {
         if (pme_gpu_settings(pmeGpu).performGPUSolve)
         {
-            pme_gpu_getEnergyAndVirial(pme, &output);
+            pme_gpu_getEnergyAndVirial(pme, lambdaQ, &output);
         }
         else
         {
@@ -714,9 +785,13 @@ void pme_gpu_update_input_box(PmeGpu gmx_unused* pmeGpu, const matrix gmx_unused
 static void pme_gpu_reinit_grids(PmeGpu* pmeGpu)
 {
     auto* kernelParamsPtr = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
+
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
     kernelParamsPtr->grid.ewaldFactor =
             (M_PI * M_PI) / (pmeGpu->common->ewaldcoeff_q * pmeGpu->common->ewaldcoeff_q);
-
     /* The grid size variants */
     for (int i = 0; i < DIM; i++)
     {
@@ -737,14 +812,17 @@ static void pme_gpu_reinit_grids(PmeGpu* pmeGpu)
         kernelParamsPtr->grid.realGridSizePadded[ZZ] =
                 (kernelParamsPtr->grid.realGridSize[ZZ] / 2 + 1) * 2;
     }
-
     /* GPU FFT: n real elements correspond to (n / 2 + 1) complex elements in minor dimension */
     kernelParamsPtr->grid.complexGridSize[ZZ] /= 2;
     kernelParamsPtr->grid.complexGridSize[ZZ]++;
     kernelParamsPtr->grid.complexGridSizePadded[ZZ] = kernelParamsPtr->grid.complexGridSize[ZZ];
 
     pme_gpu_realloc_and_copy_fract_shifts(pmeGpu);
-    pme_gpu_realloc_and_copy_bspline_values(pmeGpu);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        pme_gpu_realloc_and_copy_bspline_values(pmeGpu, gridIndex);
+    }
+
     pme_gpu_realloc_grids(pmeGpu);
     pme_gpu_reinit_3dfft(pmeGpu);
 }
@@ -765,7 +843,7 @@ static void pme_gpu_copy_common_data_from(const gmx_pme_t* pme)
     /* TODO: Consider refactoring the CPU PME code to use the same structure,
      * so that this function becomes 2 lines */
     PmeGpu* pmeGpu               = pme->gpu;
-    pmeGpu->common->ngrids       = pme->ngrids;
+    pmeGpu->common->ngrids       = pme->bFEP_q ? 2 : 1;
     pmeGpu->common->epsilon_r    = pme->epsilon_r;
     pmeGpu->common->ewaldcoeff_q = pme->ewaldcoeff_q;
     pmeGpu->common->nk[XX]       = pme->nkx;
@@ -791,7 +869,7 @@ static void pme_gpu_copy_common_data_from(const gmx_pme_t* pme)
     pmeGpu->common->nn.insert(pmeGpu->common->nn.end(), pme->nnz, pme->nnz + cellCount * pme->nkz);
     pmeGpu->common->runMode       = pme->runMode;
     pmeGpu->common->isRankPmeOnly = !pme->bPPnode;
-    pmeGpu->common->boxScaler     = pme->boxScaler;
+    pmeGpu->common->boxScaler     = pme->boxScaler.get();
 }
 
 /*! \libinternal \brief
@@ -801,15 +879,15 @@ static void pme_gpu_copy_common_data_from(const gmx_pme_t* pme)
  */
 static void pme_gpu_select_best_performing_pme_spreadgather_kernels(PmeGpu* pmeGpu)
 {
-    if (pmeGpu->kernelParams->atoms.nAtoms > c_pmeGpuPerformanceAtomLimit && (GMX_GPU == GMX_GPU_CUDA))
+    if (GMX_GPU_CUDA && pmeGpu->kernelParams->atoms.nAtoms > c_pmeGpuPerformanceAtomLimit)
     {
-        pmeGpu->settings.useOrderThreadsPerAtom = true;
-        pmeGpu->settings.recalculateSplines     = true;
+        pmeGpu->settings.threadsPerAtom     = ThreadsPerAtom::Order;
+        pmeGpu->settings.recalculateSplines = true;
     }
     else
     {
-        pmeGpu->settings.useOrderThreadsPerAtom = false;
-        pmeGpu->settings.recalculateSplines     = false;
+        pmeGpu->settings.threadsPerAtom     = ThreadsPerAtom::OrderSquared;
+        pmeGpu->settings.recalculateSplines = false;
     }
 }
 
@@ -819,10 +897,14 @@ static void pme_gpu_select_best_performing_pme_spreadgather_kernels(PmeGpu* pmeG
  * TODO: this should become PmeGpu::PmeGpu()
  *
  * \param[in,out] pme            The PME structure.
- * \param[in,out] deviceInfo     The GPU device information structure.
- * \param[in]     pmeGpuProgram  The handle to the program/kernel data created outside (e.g. in unit tests/runner)
+ * \param[in]     deviceContext  The GPU context.
+ * \param[in]     deviceStream   The GPU stream.
+ * \param[in,out] pmeGpuProgram  The handle to the program/kernel data created outside (e.g. in unit tests/runner)
  */
-static void pme_gpu_init(gmx_pme_t* pme, const DeviceInformation* deviceInfo, const PmeGpuProgram* pmeGpuProgram)
+static void pme_gpu_init(gmx_pme_t*           pme,
+                         const DeviceContext& deviceContext,
+                         const DeviceStream&  deviceStream,
+                         const PmeGpuProgram* pmeGpuProgram)
 {
     pme->gpu       = new PmeGpu();
     PmeGpu* pmeGpu = pme->gpu;
@@ -839,82 +921,20 @@ static void pme_gpu_init(gmx_pme_t* pme, const DeviceInformation* deviceInfo, co
 
     pme_gpu_set_testing(pmeGpu, false);
 
-    pmeGpu->deviceInfo = deviceInfo;
     GMX_ASSERT(pmeGpuProgram != nullptr, "GPU kernels must be already compiled");
     pmeGpu->programHandle_ = pmeGpuProgram;
 
     pmeGpu->initializedClfftLibrary_ = std::make_unique<gmx::ClfftInitializer>();
 
-    pme_gpu_init_internal(pmeGpu);
-    pme_gpu_alloc_energy_virial(pmeGpu);
+    pme_gpu_init_internal(pmeGpu, deviceContext, deviceStream);
 
     pme_gpu_copy_common_data_from(pme);
+    pme_gpu_alloc_energy_virial(pmeGpu);
 
     GMX_ASSERT(pmeGpu->common->epsilon_r != 0.0F, "PME GPU: bad electrostatic coefficient");
 
     auto* kernelParamsPtr               = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
-    kernelParamsPtr->constants.elFactor = ONE_4PI_EPS0 / pmeGpu->common->epsilon_r;
-}
-
-void pme_gpu_transform_spline_atom_data(const PmeGpu*      pmeGpu,
-                                        const PmeAtomComm* atc,
-                                        PmeSplineDataType  type,
-                                        int                dimIndex,
-                                        PmeLayoutTransform transform)
-{
-    // The GPU atom spline data is laid out in a different way currently than the CPU one.
-    // This function converts the data from GPU to CPU layout (in the host memory).
-    // It is only intended for testing purposes so far.
-    // Ideally we should use similar layouts on CPU and GPU if we care about mixed modes and their
-    // performance (e.g. spreading on GPU, gathering on CPU).
-    GMX_RELEASE_ASSERT(atc->nthread == 1, "Only the serial PME data layout is supported");
-    const uintmax_t threadIndex  = 0;
-    const auto      atomCount    = pme_gpu_get_kernel_params_base_ptr(pmeGpu)->atoms.nAtoms;
-    const auto      atomsPerWarp = pme_gpu_get_atoms_per_warp(pmeGpu);
-    const auto      pmeOrder     = pmeGpu->common->pme_order;
-    GMX_ASSERT(pmeOrder == c_pmeGpuOrder, "Only PME order 4 is implemented");
-
-    real*  cpuSplineBuffer;
-    float* h_splineBuffer;
-    switch (type)
-    {
-        case PmeSplineDataType::Values:
-            cpuSplineBuffer = atc->spline[threadIndex].theta.coefficients[dimIndex];
-            h_splineBuffer  = pmeGpu->staging.h_theta;
-            break;
-
-        case PmeSplineDataType::Derivatives:
-            cpuSplineBuffer = atc->spline[threadIndex].dtheta.coefficients[dimIndex];
-            h_splineBuffer  = pmeGpu->staging.h_dtheta;
-            break;
-
-        default: GMX_THROW(gmx::InternalError("Unknown spline data type"));
-    }
-
-    for (auto atomIndex = 0; atomIndex < atomCount; atomIndex++)
-    {
-        for (auto orderIndex = 0; orderIndex < pmeOrder; orderIndex++)
-        {
-            const auto gpuValueIndex =
-                    getSplineParamFullIndex(pmeOrder, orderIndex, dimIndex, atomIndex, atomsPerWarp);
-            const auto cpuValueIndex = atomIndex * pmeOrder + orderIndex;
-            GMX_ASSERT(cpuValueIndex < atomCount * pmeOrder,
-                       "Atom spline data index out of bounds (while transforming GPU data layout "
-                       "for host)");
-            switch (transform)
-            {
-                case PmeLayoutTransform::GpuToHost:
-                    cpuSplineBuffer[cpuValueIndex] = h_splineBuffer[gpuValueIndex];
-                    break;
-
-                case PmeLayoutTransform::HostToGpu:
-                    h_splineBuffer[gpuValueIndex] = cpuSplineBuffer[cpuValueIndex];
-                    break;
-
-                default: GMX_THROW(gmx::InternalError("Unknown layout transform"));
-            }
-        }
-    }
+    kernelParamsPtr->constants.elFactor = gmx::c_one4PiEps0 / pmeGpu->common->epsilon_r;
 }
 
 void pme_gpu_get_real_grid_sizes(const PmeGpu* pmeGpu, gmx::IVec* gridSize, gmx::IVec* paddedGridSize)
@@ -930,19 +950,21 @@ void pme_gpu_get_real_grid_sizes(const PmeGpu* pmeGpu, gmx::IVec* gridSize, gmx:
     }
 }
 
-void pme_gpu_reinit(gmx_pme_t* pme, const DeviceInformation* deviceInfo, const PmeGpuProgram* pmeGpuProgram)
+void pme_gpu_reinit(gmx_pme_t*           pme,
+                    const DeviceContext* deviceContext,
+                    const DeviceStream*  deviceStream,
+                    const PmeGpuProgram* pmeGpuProgram)
 {
     GMX_ASSERT(pme != nullptr, "Need valid PME object");
-    if (pme->runMode == PmeRunMode::CPU)
-    {
-        GMX_ASSERT(pme->gpu == nullptr, "Should not have PME GPU object");
-        return;
-    }
 
     if (!pme->gpu)
     {
+        GMX_RELEASE_ASSERT(deviceContext != nullptr,
+                           "Device context can not be nullptr when setting up PME on GPU.");
+        GMX_RELEASE_ASSERT(deviceStream != nullptr,
+                           "Device stream can not be nullptr when setting up PME on GPU.");
         /* First-time initialization */
-        pme_gpu_init(pme, deviceInfo, pmeGpuProgram);
+        pme_gpu_init(pme, *deviceContext, *deviceStream, pmeGpuProgram);
     }
     else
     {
@@ -965,7 +987,6 @@ void pme_gpu_reinit(gmx_pme_t* pme, const DeviceInformation* deviceInfo, const P
      * update for mixed mode on grid switch. TODO: use shared recipbox field.
      */
     std::memset(pme->gpu->common->previousBox, 0, sizeof(pme->gpu->common->previousBox));
-    pme_gpu_select_best_performing_pme_spreadgather_kernels(pme->gpu);
 }
 
 void pme_gpu_destroy(PmeGpu* pmeGpu)
@@ -982,29 +1003,41 @@ void pme_gpu_destroy(PmeGpu* pmeGpu)
 
     pme_gpu_destroy_3dfft(pmeGpu);
 
-    /* Free the GPU-framework specific data last */
-    pme_gpu_destroy_specific(pmeGpu);
-
     delete pmeGpu;
 }
 
-void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* charges)
+void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* chargesA, const real* chargesB)
 {
     auto* kernelParamsPtr         = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
     kernelParamsPtr->atoms.nAtoms = nAtoms;
-    const int alignment           = pme_gpu_get_atom_data_alignment(pmeGpu);
-    pmeGpu->nAtomsPadded          = ((nAtoms + alignment - 1) / alignment) * alignment;
-    const int  nAtomsAlloc        = c_usePadding ? pmeGpu->nAtomsPadded : nAtoms;
-    const bool haveToRealloc =
-            (pmeGpu->nAtomsAlloc < nAtomsAlloc); /* This check might be redundant, but is logical */
-    pmeGpu->nAtomsAlloc = nAtomsAlloc;
+    const int  block_size         = pme_gpu_get_atom_data_block_size();
+    const int  nAtomsNewPadded    = ((nAtoms + block_size - 1) / block_size) * block_size;
+    const bool haveToRealloc      = (pmeGpu->nAtomsAlloc < nAtomsNewPadded);
+    pmeGpu->nAtomsAlloc           = nAtomsNewPadded;
 
 #if GMX_DOUBLE
     GMX_RELEASE_ASSERT(false, "Only single precision supported");
     GMX_UNUSED_VALUE(charges);
 #else
-    pme_gpu_realloc_and_copy_input_coefficients(pmeGpu, reinterpret_cast<const float*>(charges));
+    int gridIndex = 0;
     /* Could also be checked for haveToRealloc, but the copy always needs to be performed */
+    pme_gpu_realloc_and_copy_input_coefficients(pmeGpu, reinterpret_cast<const float*>(chargesA), gridIndex);
+    gridIndex++;
+    if (chargesB != nullptr)
+    {
+        pme_gpu_realloc_and_copy_input_coefficients(
+                pmeGpu, reinterpret_cast<const float*>(chargesB), gridIndex);
+    }
+    else
+    {
+        /* Fill the second set of coefficients with chargesA as well to be able to avoid
+         * conditionals in the GPU kernels */
+        /* FIXME: This should be avoided by making a separate templated version of the
+         * relevant kernel(s) (probably only pme_gather_kernel). That would require a
+         * reduction of the current number of templated parameters of that kernel. */
+        pme_gpu_realloc_and_copy_input_coefficients(
+                pmeGpu, reinterpret_cast<const float*>(chargesA), gridIndex);
+    }
 #endif
 
     if (haveToRealloc)
@@ -1013,6 +1046,7 @@ void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* charges)
         pme_gpu_realloc_spline_data(pmeGpu);
         pme_gpu_realloc_grid_indices(pmeGpu);
     }
+    pme_gpu_select_best_performing_pme_spreadgather_kernels(pmeGpu);
 }
 
 /*! \internal \brief
@@ -1020,23 +1054,23 @@ void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* charges)
  * In CUDA result can be nullptr stub, per GpuRegionTimer implementation.
  *
  * \param[in] pmeGpu         The PME GPU data structure.
- * \param[in] PMEStageId     The PME GPU stage gtPME_ index from the enum in src/gromacs/timing/gpu_timing.h
+ * \param[in] pmeStageId     The PME GPU stage gtPME_ index from the enum in src/gromacs/timing/gpu_timing.h
  */
-static CommandEvent* pme_gpu_fetch_timing_event(const PmeGpu* pmeGpu, size_t PMEStageId)
+static CommandEvent* pme_gpu_fetch_timing_event(const PmeGpu* pmeGpu, PmeStage pmeStageId)
 {
     CommandEvent* timingEvent = nullptr;
     if (pme_gpu_timings_enabled(pmeGpu))
     {
-        GMX_ASSERT(PMEStageId < pmeGpu->archSpecific->timingEvents.size(),
-                   "Wrong PME GPU timing event index");
-        timingEvent = pmeGpu->archSpecific->timingEvents[PMEStageId].fetchNextEvent();
+        GMX_ASSERT(pmeStageId < PmeStage::Count, "Wrong PME GPU timing event index");
+        timingEvent = pmeGpu->archSpecific->timingEvents[pmeStageId].fetchNextEvent();
     }
     return timingEvent;
 }
 
-void pme_gpu_3dfft(const PmeGpu* pmeGpu, gmx_fft_direction dir, int grid_index)
+void pme_gpu_3dfft(const PmeGpu* pmeGpu, gmx_fft_direction dir, const int grid_index)
 {
-    int timerId = (dir == GMX_FFT_REAL_TO_COMPLEX) ? gtPME_FFT_R2C : gtPME_FFT_C2R;
+    PmeStage timerId = (dir == GMX_FFT_REAL_TO_COMPLEX) ? PmeStage::FftTransformR2C
+                                                        : PmeStage::FftTransformC2R;
 
     pme_gpu_start_timing(pmeGpu, timerId);
     pmeGpu->archSpecific->fftSetup[grid_index]->perform3dFft(
@@ -1064,34 +1098,66 @@ std::pair<int, int> inline pmeGpuCreateGrid(const PmeGpu* pmeGpu, int blockCount
  * Returns a pointer to appropriate spline and spread kernel based on the input bool values
  *
  * \param[in]  pmeGpu                   The PME GPU structure.
- * \param[in]  useOrderThreadsPerAtom   bool controlling if we should use order or order*order threads per atom
+ * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSplineAndSpreadKernelPtr(const PmeGpu* pmeGpu, bool useOrderThreadsPerAtom, bool writeSplinesToGlobal)
+static auto selectSplineAndSpreadKernelPtr(const PmeGpu*  pmeGpu,
+                                           ThreadsPerAtom threadsPerAtom,
+                                           bool           writeSplinesToGlobal,
+                                           const int      numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (writeSplinesToGlobal)
     {
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplines;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesSingle;
+            }
         }
     }
     else
     {
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelSingle;
+            }
         }
     }
 
@@ -1102,25 +1168,43 @@ static auto selectSplineAndSpreadKernelPtr(const PmeGpu* pmeGpu, bool useOrderTh
  * Returns a pointer to appropriate spline kernel based on the input bool values
  *
  * \param[in]  pmeGpu                   The PME GPU structure.
- * \param[in]  useOrderThreadsPerAtom   bool controlling if we should use order or order*order threads per atom
+ * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSplineKernelPtr(const PmeGpu* pmeGpu, bool useOrderThreadsPerAtom, bool gmx_unused writeSplinesToGlobal)
+static auto selectSplineKernelPtr(const PmeGpu*   pmeGpu,
+                                  ThreadsPerAtom  threadsPerAtom,
+                                  bool gmx_unused writeSplinesToGlobal,
+                                  const int       numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     GMX_ASSERT(
             writeSplinesToGlobal,
             "Spline data should always be written to global memory when just calculating splines");
 
-    if (useOrderThreadsPerAtom)
+    if (threadsPerAtom == ThreadsPerAtom::Order)
     {
-        kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4;
+        if (numGrids == 2)
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4Dual;
+        }
+        else
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4Single;
+        }
     }
     else
     {
-        kernelPtr = pmeGpu->programHandle_->impl_->splineKernel;
+        if (numGrids == 2)
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelDual;
+        }
+        else
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelSingle;
+        }
     }
     return kernelPtr;
 }
@@ -1129,66 +1213,106 @@ static auto selectSplineKernelPtr(const PmeGpu* pmeGpu, bool useOrderThreadsPerA
  * Returns a pointer to appropriate spread kernel based on the input bool values
  *
  * \param[in]  pmeGpu                   The PME GPU structure.
- * \param[in]  useOrderThreadsPerAtom   bool controlling if we should use order or order*order threads per atom
+ * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSpreadKernelPtr(const PmeGpu* pmeGpu, bool useOrderThreadsPerAtom, bool writeSplinesToGlobal)
+static auto selectSpreadKernelPtr(const PmeGpu*  pmeGpu,
+                                  ThreadsPerAtom threadsPerAtom,
+                                  bool           writeSplinesToGlobal,
+                                  const int      numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (writeSplinesToGlobal)
     {
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->spreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelSingle;
+            }
         }
     }
     else
     {
         /* if we are not saving the spline data we need to recalculate it
            using the spline and spread Kernel */
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelSingle;
+            }
         }
     }
     return kernelPtr;
 }
 
-void pme_gpu_spread(const PmeGpu*         pmeGpu,
-                    GpuEventSynchronizer* xReadyOnDevice,
-                    int gmx_unused gridIndex,
-                    real*          h_grid,
-                    bool           computeSplines,
-                    bool           spreadCharges)
+void pme_gpu_spread(const PmeGpu*                  pmeGpu,
+                    GpuEventSynchronizer*          xReadyOnDevice,
+                    real**                         h_grids,
+                    bool                           computeSplines,
+                    bool                           spreadCharges,
+                    const real                     lambda,
+                    const bool                     useGpuDirectComm,
+                    gmx::PmeCoordinateReceiverGpu* pmeCoordinateReceiverGpu)
 {
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
     GMX_ASSERT(computeSplines || spreadCharges,
                "PME spline/spread kernel has invalid input (nothing to do)");
-    const auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
     GMX_ASSERT(kernelParamsPtr->atoms.nAtoms > 0, "No atom data in PME GPU spread");
 
     const size_t blockSize = pmeGpu->programHandle_->impl_->spreadWorkGroupSize;
 
     const int order = pmeGpu->common->pme_order;
     GMX_ASSERT(order == c_pmeGpuOrder, "Only PME order 4 is implemented");
-    const bool writeGlobal            = pmeGpu->settings.copyAllOutputs;
-    const bool useOrderThreadsPerAtom = pmeGpu->settings.useOrderThreadsPerAtom;
-    const bool recalculateSplines     = pmeGpu->settings.recalculateSplines;
-#if GMX_GPU == GMX_GPU_OPENCL
-    GMX_ASSERT(!useOrderThreadsPerAtom, "Only 16 threads per atom supported in OpenCL");
-    GMX_ASSERT(!recalculateSplines, "Recalculating splines not supported in OpenCL");
-#endif
-    const int atomsPerBlock = useOrderThreadsPerAtom ? blockSize / c_pmeSpreadGatherThreadsPerAtom4ThPerAtom
-                                                     : blockSize / c_pmeSpreadGatherThreadsPerAtom;
+    const bool writeGlobal = pmeGpu->settings.copyAllOutputs;
+    const int  threadsPerAtom =
+            (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? order : order * order);
+    const bool recalculateSplines = pmeGpu->settings.recalculateSplines;
+
+    GMX_ASSERT(!GMX_GPU_OPENCL || pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::OrderSquared,
+               "Only 16 threads per atom supported in OpenCL");
+    GMX_ASSERT(!GMX_GPU_OPENCL || !recalculateSplines,
+               "Recalculating splines not supported in OpenCL");
+
+    const int atomsPerBlock = blockSize / threadsPerAtom;
 
     // TODO: pick smaller block size in runtime if needed
     // (e.g. on 660 Ti where 50% occupancy is ~25% faster than 100% occupancy with RNAse (~17.8k atoms))
@@ -1196,77 +1320,179 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     // TODO: test varying block sizes on modern arch-s as well
     // TODO: also consider using cudaFuncSetCacheConfig() for preferring shared memory on older architectures
     //(for spline data mostly)
-    GMX_ASSERT(!c_usePadding || !(c_pmeAtomDataAlignment % atomsPerBlock),
+    GMX_ASSERT(!(c_pmeAtomDataBlockSize % atomsPerBlock),
                "inconsistent atom data padding vs. spreading block size");
 
     // Ensure that coordinates are ready on the device before launching spread;
-    // only needed with CUDA on PP+PME ranks, not on separate PME ranks, in unit tests
-    // nor in OpenCL as these cases use a single stream (hence xReadyOnDevice == nullptr).
-    GMX_ASSERT(xReadyOnDevice != nullptr || (GMX_GPU != GMX_GPU_CUDA)
-                       || pmeGpu->common->isRankPmeOnly || pme_gpu_settings(pmeGpu).copyAllOutputs,
+    // only needed on PP+PME ranks, not on separate PME ranks, in unit tests
+    // as these cases use a single stream (hence xReadyOnDevice == nullptr).
+    GMX_ASSERT(xReadyOnDevice != nullptr || pmeGpu->common->isRankPmeOnly
+                       || pme_gpu_settings(pmeGpu).copyAllOutputs,
                "Need a valid coordinate synchronizer on PP+PME ranks with CUDA.");
+
     if (xReadyOnDevice)
     {
-        xReadyOnDevice->enqueueWaitEvent(pmeGpu->archSpecific->pmeStream);
+        xReadyOnDevice->enqueueWaitEvent(pmeGpu->archSpecific->pmeStream_);
     }
 
-    const int blockCount = pmeGpu->nAtomsPadded / atomsPerBlock;
+    const int blockCount = pmeGpu->nAtomsAlloc / atomsPerBlock;
     auto      dimGrid    = pmeGpuCreateGrid(pmeGpu, blockCount);
 
+    if (pmeGpu->common->ngrids == 1)
+    {
+        kernelParamsPtr->current.scale = 1.0;
+    }
+    else
+    {
+        kernelParamsPtr->current.scale = 1.0 - lambda;
+    }
+
     KernelLaunchConfig config;
     config.blockSize[0] = order;
-    config.blockSize[1] = useOrderThreadsPerAtom ? 1 : order;
+    config.blockSize[1] = (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? 1 : order);
     config.blockSize[2] = atomsPerBlock;
     config.gridSize[0]  = dimGrid.first;
     config.gridSize[1]  = dimGrid.second;
-    config.stream       = pmeGpu->archSpecific->pmeStream;
 
-    int                                timingId;
+    PmeStage                           timingId;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
+    const bool writeGlobalOrSaveSplines          = writeGlobal || (!recalculateSplines);
     if (computeSplines)
     {
         if (spreadCharges)
         {
-            timingId  = gtPME_SPLINEANDSPREAD;
-            kernelPtr = selectSplineAndSpreadKernelPtr(pmeGpu, useOrderThreadsPerAtom,
-                                                       writeGlobal || (!recalculateSplines));
+            timingId  = PmeStage::SplineAndSpread;
+            kernelPtr = selectSplineAndSpreadKernelPtr(pmeGpu,
+                                                       pmeGpu->settings.threadsPerAtom,
+                                                       writeGlobalOrSaveSplines,
+                                                       pmeGpu->common->ngrids);
         }
         else
         {
-            timingId  = gtPME_SPLINE;
-            kernelPtr = selectSplineKernelPtr(pmeGpu, useOrderThreadsPerAtom,
-                                              writeGlobal || (!recalculateSplines));
+            timingId  = PmeStage::Spline;
+            kernelPtr = selectSplineKernelPtr(pmeGpu,
+                                              pmeGpu->settings.threadsPerAtom,
+                                              writeGlobalOrSaveSplines,
+                                              pmeGpu->common->ngrids);
         }
     }
     else
     {
-        timingId  = gtPME_SPREAD;
-        kernelPtr = selectSpreadKernelPtr(pmeGpu, useOrderThreadsPerAtom,
-                                          writeGlobal || (!recalculateSplines));
+        timingId  = PmeStage::Spread;
+        kernelPtr = selectSpreadKernelPtr(
+                pmeGpu, pmeGpu->settings.threadsPerAtom, writeGlobalOrSaveSplines, pmeGpu->common->ngrids);
     }
 
 
     pme_gpu_start_timing(pmeGpu, timingId);
     auto* timingEvent = pme_gpu_fetch_timing_event(pmeGpu, timingId);
+
+    kernelParamsPtr->usePipeline = char(computeSplines && spreadCharges && useGpuDirectComm
+                                        && (pmeCoordinateReceiverGpu->ppCommNumSenderRanks() > 1)
+                                        && !writeGlobalOrSaveSplines);
+    if (kernelParamsPtr->usePipeline != 0)
+    {
+        int numStagesInPipeline = pmeCoordinateReceiverGpu->ppCommNumSenderRanks();
+
+        for (int i = 0; i < numStagesInPipeline; i++)
+        {
+            int senderRank;
+            if (useGpuDirectComm)
+            {
+                senderRank = pmeCoordinateReceiverGpu->synchronizeOnCoordinatesFromPpRank(
+                        i, *(pmeCoordinateReceiverGpu->ppCommStream(i)));
+            }
+            else
+            {
+                senderRank = i;
+            }
+
+            // set kernel configuration options specific to this stage of the pipeline
+            std::tie(kernelParamsPtr->pipelineAtomStart, kernelParamsPtr->pipelineAtomEnd) =
+                    pmeCoordinateReceiverGpu->ppCommAtomRange(senderRank);
+            const int blockCount       = static_cast<int>(std::ceil(
+                    static_cast<float>(kernelParamsPtr->pipelineAtomEnd - kernelParamsPtr->pipelineAtomStart)
+                    / atomsPerBlock));
+            auto      dimGrid          = pmeGpuCreateGrid(pmeGpu, blockCount);
+            config.gridSize[0]         = dimGrid.first;
+            config.gridSize[1]         = dimGrid.second;
+            DeviceStream* launchStream = pmeCoordinateReceiverGpu->ppCommStream(senderRank);
+
+
 #if c_canEmbedBuffers
-    const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
+            const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
+#else
+            const auto kernelArgs =
+                    prepareGpuKernelArguments(kernelPtr,
+                                              config,
+                                              kernelParamsPtr,
+                                              &kernelParamsPtr->atoms.d_theta,
+                                              &kernelParamsPtr->atoms.d_dtheta,
+                                              &kernelParamsPtr->atoms.d_gridlineIndices,
+                                              &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A],
+                                              &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+                                              &kernelParamsPtr->grid.d_fractShiftsTable,
+                                              &kernelParamsPtr->grid.d_gridlineIndicesTable,
+                                              &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+                                              &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+                                              &kernelParamsPtr->atoms.d_coordinates);
+#endif
+
+            launchGpuKernel(kernelPtr, config, *launchStream, timingEvent, "PME spline/spread", kernelArgs);
+        }
+        // Set dependencies for PME stream on all pipeline streams
+        for (int i = 0; i < pmeCoordinateReceiverGpu->ppCommNumSenderRanks(); i++)
+        {
+            GpuEventSynchronizer event;
+            event.markEvent(*(pmeCoordinateReceiverGpu->ppCommStream(i)));
+            event.enqueueWaitEvent(pmeGpu->archSpecific->pmeStream_);
+        }
+    }
+    else // pipelining is not in use
+    {
+        if (useGpuDirectComm) // Sync all PME-PP communications to PME stream
+        {
+            pmeCoordinateReceiverGpu->synchronizeOnCoordinatesFromAllPpRanks(pmeGpu->archSpecific->pmeStream_);
+        }
+
+#if c_canEmbedBuffers
+        const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->atoms.d_theta,
-            &kernelParamsPtr->atoms.d_dtheta, &kernelParamsPtr->atoms.d_gridlineIndices,
-            &kernelParamsPtr->grid.d_realGrid, &kernelParamsPtr->grid.d_fractShiftsTable,
-            &kernelParamsPtr->grid.d_gridlineIndicesTable, &kernelParamsPtr->atoms.d_coefficients,
-            &kernelParamsPtr->atoms.d_coordinates);
+        const auto kernelArgs =
+                prepareGpuKernelArguments(kernelPtr,
+                                          config,
+                                          kernelParamsPtr,
+                                          &kernelParamsPtr->atoms.d_theta,
+                                          &kernelParamsPtr->atoms.d_dtheta,
+                                          &kernelParamsPtr->atoms.d_gridlineIndices,
+                                          &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A],
+                                          &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+                                          &kernelParamsPtr->grid.d_fractShiftsTable,
+                                          &kernelParamsPtr->grid.d_gridlineIndicesTable,
+                                          &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+                                          &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+                                          &kernelParamsPtr->atoms.d_coordinates);
 #endif
 
-    launchGpuKernel(kernelPtr, config, timingEvent, "PME spline/spread", kernelArgs);
+        launchGpuKernel(kernelPtr,
+                        config,
+                        pmeGpu->archSpecific->pmeStream_,
+                        timingEvent,
+                        "PME spline/spread",
+                        kernelArgs);
+    }
+
     pme_gpu_stop_timing(pmeGpu, timingId);
 
     const auto& settings    = pmeGpu->settings;
     const bool copyBackGrid = spreadCharges && (!settings.performGPUFFT || settings.copyAllOutputs);
     if (copyBackGrid)
     {
-        pme_gpu_copy_output_spread_grid(pmeGpu, h_grid);
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+        {
+            float* h_grid = h_grids[gridIndex];
+            pme_gpu_copy_output_spread_grid(pmeGpu, h_grid, gridIndex);
+        }
     }
     const bool copyBackAtomData =
             computeSplines && (!settings.performGPUGather || settings.copyAllOutputs);
@@ -1276,8 +1502,18 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     }
 }
 
-void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrdering, bool computeEnergyAndVirial)
+void pme_gpu_solve(const PmeGpu* pmeGpu,
+                   const int     gridIndex,
+                   t_complex*    h_grid,
+                   GridOrdering  gridOrdering,
+                   bool          computeEnergyAndVirial)
 {
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+    GMX_ASSERT(gridIndex < pmeGpu->common->ngrids,
+               "Invalid combination of gridIndex and number of grids");
+
     const auto& settings               = pmeGpu->settings;
     const bool  copyInputAndOutputGrid = !settings.performGPUFFT || settings.copyAllOutputs;
 
@@ -1286,9 +1522,13 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
     float* h_gridFloat = reinterpret_cast<float*>(h_grid);
     if (copyInputAndOutputGrid)
     {
-        copyToDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid, h_gridFloat, 0,
-                           pmeGpu->archSpecific->complexGridSize, pmeGpu->archSpecific->pmeStream,
-                           pmeGpu->settings.transferKind, nullptr);
+        copyToDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                           h_gridFloat,
+                           0,
+                           pmeGpu->archSpecific->complexGridSize[gridIndex],
+                           pmeGpu->archSpecific->pmeStream_,
+                           pmeGpu->settings.transferKind,
+                           nullptr);
     }
 
     int majorDim = -1, middleDim = -1, minorDim = -1;
@@ -1323,10 +1563,10 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
     {
         cellsPerBlock = (gridLineSize + blocksPerGridLine - 1) / blocksPerGridLine;
     }
-    const int warpSize  = pmeGpu->programHandle_->impl_->warpSize;
+    const int warpSize  = pmeGpu->programHandle_->warpSize();
     const int blockSize = (cellsPerBlock + warpSize - 1) / warpSize * warpSize;
 
-    static_assert(GMX_GPU != GMX_GPU_CUDA || c_solveMaxWarpsPerBlock / 2 >= 4,
+    static_assert(!GMX_GPU_CUDA || c_solveMaxWarpsPerBlock / 2 >= 4,
                   "The CUDA solve energy kernels needs at least 4 warps. "
                   "Here we launch at least half of the max warps.");
 
@@ -1337,19 +1577,34 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
     config.gridSize[1] = (pmeGpu->kernelParams->grid.complexGridSize[middleDim] + gridLinesPerBlock - 1)
                          / gridLinesPerBlock;
     config.gridSize[2] = pmeGpu->kernelParams->grid.complexGridSize[majorDim];
-    config.stream      = pmeGpu->archSpecific->pmeStream;
 
-    int                                timingId  = gtPME_SOLVE;
+    PmeStage                           timingId  = PmeStage::Solve;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (gridOrdering == GridOrdering::YZX)
     {
-        kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernel
-                                           : pmeGpu->programHandle_->impl_->solveYZXKernel;
+        if (gridIndex == 0)
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernelA
+                                               : pmeGpu->programHandle_->impl_->solveYZXKernelA;
+        }
+        else
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernelB
+                                               : pmeGpu->programHandle_->impl_->solveYZXKernelB;
+        }
     }
     else if (gridOrdering == GridOrdering::XYZ)
     {
-        kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernel
-                                           : pmeGpu->programHandle_->impl_->solveXYZKernel;
+        if (gridIndex == 0)
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernelA
+                                               : pmeGpu->programHandle_->impl_->solveXYZKernelA;
+        }
+        else
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernelB
+                                               : pmeGpu->programHandle_->impl_->solveXYZKernelB;
+        }
     }
 
     pme_gpu_start_timing(pmeGpu, timingId);
@@ -1357,25 +1612,37 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
 #if c_canEmbedBuffers
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->grid.d_splineModuli,
-            &kernelParamsPtr->constants.d_virialAndEnergy, &kernelParamsPtr->grid.d_fourierGrid);
+    const auto kernelArgs =
+            prepareGpuKernelArguments(kernelPtr,
+                                      config,
+                                      kernelParamsPtr,
+                                      &kernelParamsPtr->grid.d_splineModuli[gridIndex],
+                                      &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex],
+                                      &kernelParamsPtr->grid.d_fourierGrid[gridIndex]);
 #endif
-    launchGpuKernel(kernelPtr, config, timingEvent, "PME solve", kernelArgs);
+    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME solve", kernelArgs);
     pme_gpu_stop_timing(pmeGpu, timingId);
 
     if (computeEnergyAndVirial)
     {
-        copyFromDeviceBuffer(pmeGpu->staging.h_virialAndEnergy,
-                             &kernelParamsPtr->constants.d_virialAndEnergy, 0, c_virialAndEnergyCount,
-                             pmeGpu->archSpecific->pmeStream, pmeGpu->settings.transferKind, nullptr);
+        copyFromDeviceBuffer(pmeGpu->staging.h_virialAndEnergy[gridIndex],
+                             &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex],
+                             0,
+                             c_virialAndEnergyCount,
+                             pmeGpu->archSpecific->pmeStream_,
+                             pmeGpu->settings.transferKind,
+                             nullptr);
     }
 
     if (copyInputAndOutputGrid)
     {
-        copyFromDeviceBuffer(h_gridFloat, &kernelParamsPtr->grid.d_fourierGrid, 0,
-                             pmeGpu->archSpecific->complexGridSize, pmeGpu->archSpecific->pmeStream,
-                             pmeGpu->settings.transferKind, nullptr);
+        copyFromDeviceBuffer(h_gridFloat,
+                             &kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                             0,
+                             pmeGpu->archSpecific->complexGridSize[gridIndex],
+                             pmeGpu->archSpecific->pmeStream_,
+                             pmeGpu->settings.transferKind,
+                             nullptr);
     }
 }
 
@@ -1383,66 +1650,88 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
  * Returns a pointer to appropriate gather kernel based on the inputvalues
  *
  * \param[in]  pmeGpu                   The PME GPU structure.
- * \param[in]  useOrderThreadsPerAtom   bool controlling if we should use order or order*order threads per atom
+ * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  readSplinesFromGlobal    bool controlling if we should write spline data to global memory
- * \param[in]  forceTreatment           Controls if the forces from the gather should increment or replace the input forces.
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-inline auto selectGatherKernelPtr(const PmeGpu*          pmeGpu,
-                                  bool                   useOrderThreadsPerAtom,
-                                  bool                   readSplinesFromGlobal,
-                                  PmeForceOutputHandling forceTreatment)
+inline auto selectGatherKernelPtr(const PmeGpu*  pmeGpu,
+                                  ThreadsPerAtom threadsPerAtom,
+                                  bool           readSplinesFromGlobal,
+                                  const int      numGrids)
 
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
 
     if (readSplinesFromGlobal)
     {
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = (forceTreatment == PmeForceOutputHandling::Set)
-                                ? pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4
-                                : pmeGpu->programHandle_->impl_->gatherReduceWithInputKernelReadSplinesThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = (forceTreatment == PmeForceOutputHandling::Set)
-                                ? pmeGpu->programHandle_->impl_->gatherKernelReadSplines
-                                : pmeGpu->programHandle_->impl_->gatherReduceWithInputKernelReadSplines;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesSingle;
+            }
         }
     }
     else
     {
-        if (useOrderThreadsPerAtom)
+        if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = (forceTreatment == PmeForceOutputHandling::Set)
-                                ? pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4
-                                : pmeGpu->programHandle_->impl_->gatherReduceWithInputKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = (forceTreatment == PmeForceOutputHandling::Set)
-                                ? pmeGpu->programHandle_->impl_->gatherKernel
-                                : pmeGpu->programHandle_->impl_->gatherReduceWithInputKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelSingle;
+            }
         }
     }
     return kernelPtr;
 }
 
-
-void pme_gpu_gather(PmeGpu* pmeGpu, PmeForceOutputHandling forceTreatment, const float* h_grid)
+void pme_gpu_gather(PmeGpu* pmeGpu, real** h_grids, const float lambda)
 {
-    /* Copying the input CPU forces for reduction */
-    if (forceTreatment != PmeForceOutputHandling::Set)
-    {
-        pme_gpu_copy_input_forces(pmeGpu);
-    }
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
 
     const auto& settings = pmeGpu->settings;
+
     if (!settings.performGPUFFT || settings.copyAllOutputs)
     {
-        pme_gpu_copy_input_gather_grid(pmeGpu, const_cast<float*>(h_grid));
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+        {
+            float* h_grid = const_cast<float*>(h_grids[gridIndex]);
+            pme_gpu_copy_input_gather_grid(pmeGpu, h_grid, gridIndex);
+        }
     }
 
     if (settings.copyAllOutputs)
@@ -1451,59 +1740,78 @@ void pme_gpu_gather(PmeGpu* pmeGpu, PmeForceOutputHandling forceTreatment, const
     }
 
     /* Set if we have unit tests */
-    const bool   readGlobal             = pmeGpu->settings.copyAllOutputs;
-    const size_t blockSize              = pmeGpu->programHandle_->impl_->gatherWorkGroupSize;
-    const bool   useOrderThreadsPerAtom = pmeGpu->settings.useOrderThreadsPerAtom;
-    const bool   recalculateSplines     = pmeGpu->settings.recalculateSplines;
-#if GMX_GPU == GMX_GPU_OPENCL
-    GMX_ASSERT(!useOrderThreadsPerAtom, "Only 16 threads per atom supported in OpenCL");
-    GMX_ASSERT(!recalculateSplines, "Recalculating splines not supported in OpenCL");
-#endif
-    const int atomsPerBlock = useOrderThreadsPerAtom ? blockSize / c_pmeSpreadGatherThreadsPerAtom4ThPerAtom
-                                                     : blockSize / c_pmeSpreadGatherThreadsPerAtom;
+    const bool   readGlobal = pmeGpu->settings.copyAllOutputs;
+    const size_t blockSize  = pmeGpu->programHandle_->impl_->gatherWorkGroupSize;
+    const int    order      = pmeGpu->common->pme_order;
+    GMX_ASSERT(order == c_pmeGpuOrder, "Only PME order 4 is implemented");
+    const int threadsPerAtom =
+            (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? order : order * order);
+    const bool recalculateSplines = pmeGpu->settings.recalculateSplines;
 
-    GMX_ASSERT(!c_usePadding || !(c_pmeAtomDataAlignment % atomsPerBlock),
+    GMX_ASSERT(!GMX_GPU_OPENCL || pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::OrderSquared,
+               "Only 16 threads per atom supported in OpenCL");
+    GMX_ASSERT(!GMX_GPU_OPENCL || !recalculateSplines,
+               "Recalculating splines not supported in OpenCL");
+
+    const int atomsPerBlock = blockSize / threadsPerAtom;
+
+    GMX_ASSERT(!(c_pmeAtomDataBlockSize % atomsPerBlock),
                "inconsistent atom data padding vs. gathering block size");
 
-    const int blockCount = pmeGpu->nAtomsPadded / atomsPerBlock;
+    const int blockCount = pmeGpu->nAtomsAlloc / atomsPerBlock;
     auto      dimGrid    = pmeGpuCreateGrid(pmeGpu, blockCount);
 
-    const int order = pmeGpu->common->pme_order;
-    GMX_ASSERT(order == c_pmeGpuOrder, "Only PME order 4 is implemented");
-
     KernelLaunchConfig config;
     config.blockSize[0] = order;
-    config.blockSize[1] = useOrderThreadsPerAtom ? 1 : order;
+    config.blockSize[1] = (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? 1 : order);
     config.blockSize[2] = atomsPerBlock;
     config.gridSize[0]  = dimGrid.first;
     config.gridSize[1]  = dimGrid.second;
-    config.stream       = pmeGpu->archSpecific->pmeStream;
 
     // TODO test different cache configs
 
-    int                                timingId  = gtPME_GATHER;
-    PmeGpuProgramImpl::PmeKernelHandle kernelPtr = selectGatherKernelPtr(
-            pmeGpu, useOrderThreadsPerAtom, readGlobal || (!recalculateSplines), forceTreatment);
+    PmeStage                           timingId = PmeStage::Gather;
+    PmeGpuProgramImpl::PmeKernelHandle kernelPtr =
+            selectGatherKernelPtr(pmeGpu,
+                                  pmeGpu->settings.threadsPerAtom,
+                                  readGlobal || (!recalculateSplines),
+                                  pmeGpu->common->ngrids);
     // TODO design kernel selection getters and make PmeGpu a friend of PmeGpuProgramImpl
 
     pme_gpu_start_timing(pmeGpu, timingId);
-    auto*       timingEvent     = pme_gpu_fetch_timing_event(pmeGpu, timingId);
-    const auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* timingEvent     = pme_gpu_fetch_timing_event(pmeGpu, timingId);
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+    if (pmeGpu->common->ngrids == 1)
+    {
+        kernelParamsPtr->current.scale = 1.0;
+    }
+    else
+    {
+        kernelParamsPtr->current.scale = 1.0 - lambda;
+    }
+
 #if c_canEmbedBuffers
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->atoms.d_coefficients,
-            &kernelParamsPtr->grid.d_realGrid, &kernelParamsPtr->atoms.d_theta,
-            &kernelParamsPtr->atoms.d_dtheta, &kernelParamsPtr->atoms.d_gridlineIndices,
-            &kernelParamsPtr->atoms.d_forces);
+    const auto kernelArgs =
+            prepareGpuKernelArguments(kernelPtr,
+                                      config,
+                                      kernelParamsPtr,
+                                      &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+                                      &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+                                      &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A],
+                                      &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+                                      &kernelParamsPtr->atoms.d_theta,
+                                      &kernelParamsPtr->atoms.d_dtheta,
+                                      &kernelParamsPtr->atoms.d_gridlineIndices,
+                                      &kernelParamsPtr->atoms.d_forces);
 #endif
-    launchGpuKernel(kernelPtr, config, timingEvent, "PME gather", kernelArgs);
+    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME gather", kernelArgs);
     pme_gpu_stop_timing(pmeGpu, timingId);
 
     if (pmeGpu->settings.useGpuForceReduction)
     {
-        pmeGpu->archSpecific->pmeForcesReady.markEvent(pmeGpu->archSpecific->pmeStream);
+        pmeGpu->archSpecific->pmeForcesReady.markEvent(pmeGpu->archSpecific->pmeStream_);
     }
     else
     {
@@ -1511,7 +1819,7 @@ void pme_gpu_gather(PmeGpu* pmeGpu, PmeForceOutputHandling forceTreatment, const
     }
 }
 
-void* pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
+DeviceBuffer<gmx::RVec> pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
 {
     if (pmeGpu && pmeGpu->kernelParams)
     {
@@ -1519,11 +1827,11 @@ void* pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
     }
     else
     {
-        return nullptr;
+        return DeviceBuffer<gmx::RVec>{};
     }
 }
 
-void pme_gpu_set_kernelparam_coordinates(const PmeGpu* pmeGpu, DeviceBuffer<float> d_x)
+void pme_gpu_set_kernelparam_coordinates(const PmeGpu* pmeGpu, DeviceBuffer<gmx::RVec> d_x)
 {
     GMX_ASSERT(pmeGpu && pmeGpu->kernelParams,
                "PME GPU device buffer can not be set in non-GPU builds or before the GPU PME was "
@@ -1535,30 +1843,6 @@ void pme_gpu_set_kernelparam_coordinates(const PmeGpu* pmeGpu, DeviceBuffer<floa
     pmeGpu->kernelParams->atoms.d_coordinates = d_x;
 }
 
-void* pme_gpu_get_stream(const PmeGpu* pmeGpu)
-{
-    if (pmeGpu)
-    {
-        return static_cast<void*>(&pmeGpu->archSpecific->pmeStream);
-    }
-    else
-    {
-        return nullptr;
-    }
-}
-
-void* pme_gpu_get_context(const PmeGpu* pmeGpu)
-{
-    if (pmeGpu)
-    {
-        return static_cast<void*>(&pmeGpu->archSpecific->context);
-    }
-    else
-    {
-        return nullptr;
-    }
-}
-
 GpuEventSynchronizer* pme_gpu_get_forces_ready_synchronizer(const PmeGpu* pmeGpu)
 {
     if (pmeGpu && pmeGpu->kernelParams)