Unify coordinate copy handling across GPU platforms
src/gromacs/ewald/pme_gpu_internal.cpp
index 63b77aa86f81326aee15640c164f02ee64722ab1..5e5bfd8ebeac7cdd9b6cf22ebff2f2dbf0cfc1ab 100644
@@ -1,7 +1,8 @@
 /*
  * This file is part of the GROMACS molecular simulation package.
  *
- * Copyright (c) 2016,2017,2018,2019,2020, by the GROMACS development team, led by
+ * Copyright (c) 2016,2017,2018,2019,2020 by the GROMACS development team.
+ * Copyright (c) 2021, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
 #include <string>
 
 #include "gromacs/ewald/ewald_utils.h"
+#include "gromacs/fft/gpu_3dfft.h"
 #include "gromacs/gpu_utils/device_context.h"
 #include "gromacs/gpu_utils/device_stream.h"
 #include "gromacs/gpu_utils/gpu_utils.h"
+#include "gromacs/gpu_utils/pmalloc.h"
+#if GMX_GPU_SYCL
+#    include "gromacs/gpu_utils/syclutils.h"
+#endif
+#include "gromacs/hardware/device_information.h"
 #include "gromacs/math/invertmatrix.h"
 #include "gromacs/math/units.h"
 #include "gromacs/timing/gpu_timing.h"
 #include "gromacs/utility/gmxassert.h"
 #include "gromacs/utility/logger.h"
 #include "gromacs/utility/stringutil.h"
+#include "gromacs/ewald/pme.h"
+#include "gromacs/ewald/pme_coordinate_receiver_gpu.h"
 
-#if GMX_GPU == GMX_GPU_CUDA
-#    include "gromacs/gpu_utils/pmalloc_cuda.h"
-
+#if GMX_GPU_CUDA
 #    include "pme.cuh"
-#elif GMX_GPU == GMX_GPU_OPENCL
-#    include "gromacs/gpu_utils/gmxopencl.h"
 #endif
 
-#include "gromacs/ewald/pme.h"
-
-#include "pme_gpu_3dfft.h"
 #include "pme_gpu_calculate_splines.h"
 #include "pme_gpu_constants.h"
 #include "pme_gpu_program_impl.h"
@@ -93,7 +95,7 @@
 /*! \brief
  * CUDA only
  * Atom limit above which it is advantageous to turn on the
- * recalcuating of the splines in the gather and using less threads per atom in the spline and spread
+ * recalculating the splines in the gather and using fewer threads per atom in the spline and spread
  */
 constexpr int c_pmeGpuPerformanceAtomLimit = 23000;
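
To make the role of this constant concrete, here is a hedged sketch of how it is consumed later in this same file (see pme_gpu_select_best_performing_pme_spreadgather_kernels below). Only the CUDA large-system branch is shown verbatim in this patch; the wrapper name and the stated defaults for the other branch are assumptions:

// Sketch only, mirroring the selection logic further down in this patch.
static void selectKernelVariantForSystemSize(PmeGpu* pmeGpu)
{
    if (GMX_GPU_CUDA && pmeGpu->kernelParams->atoms.nAtoms > c_pmeGpuPerformanceAtomLimit)
    {
        // Large CUDA systems: order threads per atom, recompute splines in the gather.
        pmeGpu->settings.threadsPerAtom     = ThreadsPerAtom::Order;
        pmeGpu->settings.recalculateSplines = true;
    }
    // Otherwise the defaults (order*order threads per atom, stored splines) are assumed to apply.
}
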
 
@@ -133,27 +135,51 @@ void pme_gpu_synchronize(const PmeGpu* pmeGpu)
 void pme_gpu_alloc_energy_virial(PmeGpu* pmeGpu)
 {
     const size_t energyAndVirialSize = c_virialAndEnergyCount * sizeof(float);
-    allocateDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy, c_virialAndEnergyCount,
-                         pmeGpu->archSpecific->deviceContext_);
-    pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_virialAndEnergy), energyAndVirialSize);
+
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        allocateDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex],
+                             c_virialAndEnergyCount,
+                             pmeGpu->archSpecific->deviceContext_);
+        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_virialAndEnergy[gridIndex]), energyAndVirialSize);
+    }
 }
 
 void pme_gpu_free_energy_virial(PmeGpu* pmeGpu)
 {
-    freeDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy);
-    pfree(pmeGpu->staging.h_virialAndEnergy);
-    pmeGpu->staging.h_virialAndEnergy = nullptr;
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        freeDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex]);
+        pfree(pmeGpu->staging.h_virialAndEnergy[gridIndex]);
+        pmeGpu->staging.h_virialAndEnergy[gridIndex] = nullptr;
+    }
 }
 
 void pme_gpu_clear_energy_virial(const PmeGpu* pmeGpu)
 {
-    clearDeviceBufferAsync(&pmeGpu->kernelParams->constants.d_virialAndEnergy, 0,
-                           c_virialAndEnergyCount, pmeGpu->archSpecific->pmeStream_);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex],
+                               0,
+                               c_virialAndEnergyCount,
+                               pmeGpu->archSpecific->pmeStream_);
+    }
 }
 
-void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu)
+void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu, const int gridIndex)
 {
-    const int splineValuesOffset[DIM] = { 0, pmeGpu->kernelParams->grid.realGridSize[XX],
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+    GMX_ASSERT(gridIndex < pmeGpu->common->ngrids,
+               "Invalid combination of gridIndex and number of grids");
+
+    const int splineValuesOffset[DIM] = { 0,
+                                          pmeGpu->kernelParams->grid.realGridSize[XX],
                                           pmeGpu->kernelParams->grid.realGridSize[XX]
                                                   + pmeGpu->kernelParams->grid.realGridSize[YY] };
     memcpy(&pmeGpu->kernelParams->grid.splineValuesOffset, &splineValuesOffset, sizeof(splineValuesOffset));
@@ -161,41 +187,52 @@ void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu)
     const int newSplineValuesSize = pmeGpu->kernelParams->grid.realGridSize[XX]
                                     + pmeGpu->kernelParams->grid.realGridSize[YY]
                                     + pmeGpu->kernelParams->grid.realGridSize[ZZ];
-    const bool shouldRealloc = (newSplineValuesSize > pmeGpu->archSpecific->splineValuesSize);
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli, newSplineValuesSize,
-                           &pmeGpu->archSpecific->splineValuesSize,
-                           &pmeGpu->archSpecific->splineValuesSizeAlloc,
+    const bool shouldRealloc = (newSplineValuesSize > pmeGpu->archSpecific->splineValuesSize[gridIndex]);
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex],
+                           newSplineValuesSize,
+                           &pmeGpu->archSpecific->splineValuesSize[gridIndex],
+                           &pmeGpu->archSpecific->splineValuesCapacity[gridIndex],
                            pmeGpu->archSpecific->deviceContext_);
     if (shouldRealloc)
     {
         /* Reallocate the host buffer */
-        pfree(pmeGpu->staging.h_splineModuli);
-        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_splineModuli),
+        pfree(pmeGpu->staging.h_splineModuli[gridIndex]);
+        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_splineModuli[gridIndex]),
                 newSplineValuesSize * sizeof(float));
     }
     for (int i = 0; i < DIM; i++)
     {
-        memcpy(pmeGpu->staging.h_splineModuli + splineValuesOffset[i],
-               pmeGpu->common->bsp_mod[i].data(), pmeGpu->common->bsp_mod[i].size() * sizeof(float));
+        memcpy(pmeGpu->staging.h_splineModuli[gridIndex] + splineValuesOffset[i],
+               pmeGpu->common->bsp_mod[i].data(),
+               pmeGpu->common->bsp_mod[i].size() * sizeof(float));
     }
     /* TODO: pin original buffer instead! */
-    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli, pmeGpu->staging.h_splineModuli,
-                       0, newSplineValuesSize, pmeGpu->archSpecific->pmeStream_,
-                       pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex],
+                       pmeGpu->staging.h_splineModuli[gridIndex],
+                       0,
+                       newSplineValuesSize,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_free_bspline_values(const PmeGpu* pmeGpu)
 {
-    pfree(pmeGpu->staging.h_splineModuli);
-    freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        pfree(pmeGpu->staging.h_splineModuli[gridIndex]);
+        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex]);
+    }
 }
 
 void pme_gpu_realloc_forces(PmeGpu* pmeGpu)
 {
-    const size_t newForcesSize = pmeGpu->nAtomsAlloc * DIM;
+    const size_t newForcesSize = pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newForcesSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces, newForcesSize,
-                           &pmeGpu->archSpecific->forcesSize, &pmeGpu->archSpecific->forcesSizeAlloc,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces,
+                           newForcesSize,
+                           &pmeGpu->archSpecific->forcesSize,
+                           &pmeGpu->archSpecific->forcesSizeAlloc,
                            pmeGpu->archSpecific->deviceContext_);
     pmeGpu->staging.h_forces.reserveWithPadding(pmeGpu->nAtomsAlloc);
     pmeGpu->staging.h_forces.resizeWithPadding(pmeGpu->kernelParams->atoms.nAtoms);
@@ -209,46 +246,64 @@ void pme_gpu_free_forces(const PmeGpu* pmeGpu)
 void pme_gpu_copy_input_forces(PmeGpu* pmeGpu)
 {
     GMX_ASSERT(pmeGpu->kernelParams->atoms.nAtoms > 0, "Bad number of atoms in PME GPU");
-    float* h_forcesFloat = reinterpret_cast<float*>(pmeGpu->staging.h_forces.data());
-    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces, h_forcesFloat, 0,
-                       DIM * pmeGpu->kernelParams->atoms.nAtoms, pmeGpu->archSpecific->pmeStream_,
-                       pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_forces,
+                       pmeGpu->staging.h_forces.data(),
+                       0,
+                       pmeGpu->kernelParams->atoms.nAtoms,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_copy_output_forces(PmeGpu* pmeGpu)
 {
     GMX_ASSERT(pmeGpu->kernelParams->atoms.nAtoms > 0, "Bad number of atoms in PME GPU");
-    float* h_forcesFloat = reinterpret_cast<float*>(pmeGpu->staging.h_forces.data());
-    copyFromDeviceBuffer(h_forcesFloat, &pmeGpu->kernelParams->atoms.d_forces, 0,
-                         DIM * pmeGpu->kernelParams->atoms.nAtoms, pmeGpu->archSpecific->pmeStream_,
-                         pmeGpu->settings.transferKind, nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_forces.data(),
+                         &pmeGpu->kernelParams->atoms.d_forces,
+                         0,
+                         pmeGpu->kernelParams->atoms.nAtoms,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
 }
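
The two force-transfer hunks above drop both the float* reinterpret_cast and the DIM factor: the device force buffer is now addressed per atom (one three-component element per atom) rather than per float, so transfer sizes are counted in atoms. The element type itself is not visible in this excerpt, so the contrast below is a hedged sketch:

// Hedged sketch (assumes an RVec-like element type for d_forces, inferred from the dropped cast):
// old: copyToDeviceBuffer(&d_forces, reinterpret_cast<float*>(h_forces.data()), 0, DIM * nAtoms, ...);
// new: copyToDeviceBuffer(&d_forces, h_forces.data(), 0, nAtoms, ...);
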
 
-void pme_gpu_realloc_and_copy_input_coefficients(PmeGpu* pmeGpu, const float* h_coefficients)
+void pme_gpu_realloc_and_copy_input_coefficients(const PmeGpu* pmeGpu,
+                                                 const float*  h_coefficients,
+                                                 const int     gridIndex)
 {
     GMX_ASSERT(h_coefficients, "Bad host-side charge buffer in PME GPU");
     const size_t newCoefficientsSize = pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newCoefficientsSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients, newCoefficientsSize,
-                           &pmeGpu->archSpecific->coefficientsSize,
-                           &pmeGpu->archSpecific->coefficientsSizeAlloc,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
+                           newCoefficientsSize,
+                           &pmeGpu->archSpecific->coefficientsSize[gridIndex],
+                           &pmeGpu->archSpecific->coefficientsCapacity[gridIndex],
                            pmeGpu->archSpecific->deviceContext_);
-    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients,
-                       const_cast<float*>(h_coefficients), 0, pmeGpu->kernelParams->atoms.nAtoms,
-                       pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
+                       const_cast<float*>(h_coefficients),
+                       0,
+                       pmeGpu->kernelParams->atoms.nAtoms,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 
     const size_t paddingIndex = pmeGpu->kernelParams->atoms.nAtoms;
     const size_t paddingCount = pmeGpu->nAtomsAlloc - paddingIndex;
     if (paddingCount > 0)
     {
-        clearDeviceBufferAsync(&pmeGpu->kernelParams->atoms.d_coefficients, paddingIndex,
-                               paddingCount, pmeGpu->archSpecific->pmeStream_);
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
+                               paddingIndex,
+                               paddingCount,
+                               pmeGpu->archSpecific->pmeStream_);
     }
 }
 
 void pme_gpu_free_coefficients(const PmeGpu* pmeGpu)
 {
-    freeDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        freeDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex]);
+    }
 }
 
 void pme_gpu_realloc_spline_data(PmeGpu* pmeGpu)
@@ -260,10 +315,15 @@ void pme_gpu_realloc_spline_data(PmeGpu* pmeGpu)
     const bool shouldRealloc        = (newSplineDataSize > pmeGpu->archSpecific->splineDataSize);
     int        currentSizeTemp      = pmeGpu->archSpecific->splineDataSize;
     int        currentSizeTempAlloc = pmeGpu->archSpecific->splineDataSizeAlloc;
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_theta, newSplineDataSize, &currentSizeTemp,
-                           &currentSizeTempAlloc, pmeGpu->archSpecific->deviceContext_);
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_dtheta, newSplineDataSize,
-                           &pmeGpu->archSpecific->splineDataSize, &pmeGpu->archSpecific->splineDataSizeAlloc,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_theta,
+                           newSplineDataSize,
+                           &currentSizeTemp,
+                           &currentSizeTempAlloc,
+                           pmeGpu->archSpecific->deviceContext_);
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_dtheta,
+                           newSplineDataSize,
+                           &pmeGpu->archSpecific->splineDataSize,
+                           &pmeGpu->archSpecific->splineDataSizeAlloc,
                            pmeGpu->archSpecific->deviceContext_);
     // the host side reallocation
     if (shouldRealloc)
@@ -288,7 +348,8 @@ void pme_gpu_realloc_grid_indices(PmeGpu* pmeGpu)
 {
     const size_t newIndicesSize = DIM * pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newIndicesSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_gridlineIndices, newIndicesSize,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_gridlineIndices,
+                           newIndicesSize,
                            &pmeGpu->archSpecific->gridlineIndicesSize,
                            &pmeGpu->archSpecific->gridlineIndicesSizeAlloc,
                            pmeGpu->archSpecific->deviceContext_);
@@ -304,51 +365,69 @@ void pme_gpu_free_grid_indices(const PmeGpu* pmeGpu)
 
 void pme_gpu_realloc_grids(PmeGpu* pmeGpu)
 {
-    auto*     kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+
     const int newRealGridSize = kernelParamsPtr->grid.realGridSizePadded[XX]
                                 * kernelParamsPtr->grid.realGridSizePadded[YY]
                                 * kernelParamsPtr->grid.realGridSizePadded[ZZ];
     const int newComplexGridSize = kernelParamsPtr->grid.complexGridSizePadded[XX]
                                    * kernelParamsPtr->grid.complexGridSizePadded[YY]
                                    * kernelParamsPtr->grid.complexGridSizePadded[ZZ] * 2;
-    // Multiplied by 2 because we count complex grid size for complex numbers, but all allocations/pointers are float
-    if (pmeGpu->archSpecific->performOutOfPlaceFFT)
-    {
-        /* 2 separate grids */
-        reallocateDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid, newComplexGridSize,
-                               &pmeGpu->archSpecific->complexGridSize,
-                               &pmeGpu->archSpecific->complexGridSizeAlloc,
-                               pmeGpu->archSpecific->deviceContext_);
-        reallocateDeviceBuffer(
-                &kernelParamsPtr->grid.d_realGrid, newRealGridSize, &pmeGpu->archSpecific->realGridSize,
-                &pmeGpu->archSpecific->realGridSizeAlloc, pmeGpu->archSpecific->deviceContext_);
-    }
-    else
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        /* A single buffer so that any grid will fit */
-        const int newGridsSize = std::max(newRealGridSize, newComplexGridSize);
-        reallocateDeviceBuffer(
-                &kernelParamsPtr->grid.d_realGrid, newGridsSize, &pmeGpu->archSpecific->realGridSize,
-                &pmeGpu->archSpecific->realGridSizeAlloc, pmeGpu->archSpecific->deviceContext_);
-        kernelParamsPtr->grid.d_fourierGrid   = kernelParamsPtr->grid.d_realGrid;
-        pmeGpu->archSpecific->complexGridSize = pmeGpu->archSpecific->realGridSize;
-        // the size might get used later for copying the grid
+        // Multiplied by 2 because we count complex grid size for complex numbers, but all allocations/pointers are float
+        if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+        {
+            /* 2 separate grids */
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                                   newComplexGridSize,
+                                   &pmeGpu->archSpecific->complexGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->complexGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex],
+                                   newRealGridSize,
+                                   &pmeGpu->archSpecific->realGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->realGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+        }
+        else
+        {
+            /* A single buffer so that any grid will fit */
+            const int newGridsSize = std::max(newRealGridSize, newComplexGridSize);
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex],
+                                   newGridsSize,
+                                   &pmeGpu->archSpecific->realGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->realGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+            kernelParamsPtr->grid.d_fourierGrid[gridIndex] = kernelParamsPtr->grid.d_realGrid[gridIndex];
+            pmeGpu->archSpecific->complexGridSize[gridIndex] =
+                    pmeGpu->archSpecific->realGridSize[gridIndex];
+            // the size might get used later for copying the grid
+        }
     }
 }
 
 void pme_gpu_free_grids(const PmeGpu* pmeGpu)
 {
-    if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_fourierGrid);
+        if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+        {
+            freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_fourierGrid[gridIndex]);
+        }
+        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex]);
     }
-    freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid);
 }
 
 void pme_gpu_clear_grids(const PmeGpu* pmeGpu)
 {
-    clearDeviceBufferAsync(&pmeGpu->kernelParams->grid.d_realGrid, 0,
-                           pmeGpu->archSpecific->realGridSize, pmeGpu->archSpecific->pmeStream_);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                               0,
+                               pmeGpu->archSpecific->realGridSize[gridIndex],
+                               pmeGpu->archSpecific->pmeStream_);
+    }
 }
 
 void pme_gpu_realloc_and_copy_fract_shifts(PmeGpu* pmeGpu)
@@ -368,23 +447,27 @@ void pme_gpu_realloc_and_copy_fract_shifts(PmeGpu* pmeGpu)
     const int newFractShiftsSize = cellCount * (nx + ny + nz);
 
     initParamLookupTable(&kernelParamsPtr->grid.d_fractShiftsTable,
-                         &kernelParamsPtr->fractShiftsTableTexture, pmeGpu->common->fsh.data(),
-                         newFractShiftsSize, pmeGpu->archSpecific->deviceContext_);
+                         &kernelParamsPtr->fractShiftsTableTexture,
+                         pmeGpu->common->fsh.data(),
+                         newFractShiftsSize,
+                         pmeGpu->archSpecific->deviceContext_);
 
     initParamLookupTable(&kernelParamsPtr->grid.d_gridlineIndicesTable,
-                         &kernelParamsPtr->gridlineIndicesTableTexture, pmeGpu->common->nn.data(),
-                         newFractShiftsSize, pmeGpu->archSpecific->deviceContext_);
+                         &kernelParamsPtr->gridlineIndicesTableTexture,
+                         pmeGpu->common->nn.data(),
+                         newFractShiftsSize,
+                         pmeGpu->archSpecific->deviceContext_);
 }
 
 void pme_gpu_free_fract_shifts(const PmeGpu* pmeGpu)
 {
     auto* kernelParamsPtr = pmeGpu->kernelParams.get();
-#if GMX_GPU == GMX_GPU_CUDA
+#if GMX_GPU_CUDA
     destroyParamLookupTable(&kernelParamsPtr->grid.d_fractShiftsTable,
-                            kernelParamsPtr->fractShiftsTableTexture);
+                            &kernelParamsPtr->fractShiftsTableTexture);
     destroyParamLookupTable(&kernelParamsPtr->grid.d_gridlineIndicesTable,
-                            kernelParamsPtr->gridlineIndicesTableTexture);
-#elif GMX_GPU == GMX_GPU_OPENCL
+                            &kernelParamsPtr->gridlineIndicesTableTexture);
+#elif GMX_GPU_OPENCL || GMX_GPU_SYCL
     freeDeviceBuffer(&kernelParamsPtr->grid.d_fractShiftsTable);
     freeDeviceBuffer(&kernelParamsPtr->grid.d_gridlineIndicesTable);
 #endif
@@ -395,17 +478,26 @@ bool pme_gpu_stream_query(const PmeGpu* pmeGpu)
     return haveStreamTasksCompleted(pmeGpu->archSpecific->pmeStream_);
 }
 
-void pme_gpu_copy_input_gather_grid(const PmeGpu* pmeGpu, float* h_grid)
+void pme_gpu_copy_input_gather_grid(const PmeGpu* pmeGpu, const float* h_grid, const int gridIndex)
 {
-    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid, h_grid, 0, pmeGpu->archSpecific->realGridSize,
-                       pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                       h_grid,
+                       0,
+                       pmeGpu->archSpecific->realGridSize[gridIndex],
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
-void pme_gpu_copy_output_spread_grid(const PmeGpu* pmeGpu, float* h_grid)
+void pme_gpu_copy_output_spread_grid(const PmeGpu* pmeGpu, float* h_grid, const int gridIndex)
 {
-    copyFromDeviceBuffer(h_grid, &pmeGpu->kernelParams->grid.d_realGrid, 0,
-                         pmeGpu->archSpecific->realGridSize, pmeGpu->archSpecific->pmeStream_,
-                         pmeGpu->settings.transferKind, nullptr);
+    copyFromDeviceBuffer(h_grid,
+                         &pmeGpu->kernelParams->grid.d_realGrid[gridIndex],
+                         0,
+                         pmeGpu->archSpecific->realGridSize[gridIndex],
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
     pmeGpu->archSpecific->syncSpreadGridD2H.markEvent(pmeGpu->archSpecific->pmeStream_);
 }
 
@@ -413,13 +505,27 @@ void pme_gpu_copy_output_spread_atom_data(const PmeGpu* pmeGpu)
 {
     const size_t splinesCount    = DIM * pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order;
     auto*        kernelParamsPtr = pmeGpu->kernelParams.get();
-    copyFromDeviceBuffer(pmeGpu->staging.h_dtheta, &kernelParamsPtr->atoms.d_dtheta, 0, splinesCount,
-                         pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
-    copyFromDeviceBuffer(pmeGpu->staging.h_theta, &kernelParamsPtr->atoms.d_theta, 0, splinesCount,
-                         pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
-    copyFromDeviceBuffer(pmeGpu->staging.h_gridlineIndices, &kernelParamsPtr->atoms.d_gridlineIndices,
-                         0, kernelParamsPtr->atoms.nAtoms * DIM, pmeGpu->archSpecific->pmeStream_,
-                         pmeGpu->settings.transferKind, nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_dtheta,
+                         &kernelParamsPtr->atoms.d_dtheta,
+                         0,
+                         splinesCount,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_theta,
+                         &kernelParamsPtr->atoms.d_theta,
+                         0,
+                         splinesCount,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
+    copyFromDeviceBuffer(pmeGpu->staging.h_gridlineIndices,
+                         &kernelParamsPtr->atoms.d_gridlineIndices,
+                         0,
+                         kernelParamsPtr->atoms.nAtoms * DIM,
+                         pmeGpu->archSpecific->pmeStream_,
+                         pmeGpu->settings.transferKind,
+                         nullptr);
 }
 
 void pme_gpu_copy_input_gather_atom_data(const PmeGpu* pmeGpu)
@@ -428,22 +534,40 @@ void pme_gpu_copy_input_gather_atom_data(const PmeGpu* pmeGpu)
     auto*        kernelParamsPtr = pmeGpu->kernelParams.get();
 
     // TODO: could clear only the padding and not the whole thing, but this is a test-exclusive code anyway
-    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_gridlineIndices, 0, pmeGpu->nAtomsAlloc * DIM,
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_gridlineIndices,
+                           0,
+                           pmeGpu->nAtomsAlloc * DIM,
                            pmeGpu->archSpecific->pmeStream_);
-    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_dtheta, 0,
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_dtheta,
+                           0,
                            pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
                            pmeGpu->archSpecific->pmeStream_);
-    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_theta, 0,
+    clearDeviceBufferAsync(&kernelParamsPtr->atoms.d_theta,
+                           0,
                            pmeGpu->nAtomsAlloc * pmeGpu->common->pme_order * DIM,
                            pmeGpu->archSpecific->pmeStream_);
 
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_dtheta, pmeGpu->staging.h_dtheta, 0, splinesCount,
-                       pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_theta, pmeGpu->staging.h_theta, 0, splinesCount,
-                       pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
-    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_gridlineIndices, pmeGpu->staging.h_gridlineIndices,
-                       0, kernelParamsPtr->atoms.nAtoms * DIM, pmeGpu->archSpecific->pmeStream_,
-                       pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_dtheta,
+                       pmeGpu->staging.h_dtheta,
+                       0,
+                       splinesCount,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_theta,
+                       pmeGpu->staging.h_theta,
+                       0,
+                       splinesCount,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
+    copyToDeviceBuffer(&kernelParamsPtr->atoms.d_gridlineIndices,
+                       pmeGpu->staging.h_gridlineIndices,
+                       0,
+                       kernelParamsPtr->atoms.nAtoms * DIM,
+                       pmeGpu->archSpecific->pmeStream_,
+                       pmeGpu->settings.transferKind,
+                       nullptr);
 }
 
 void pme_gpu_sync_spread_grid(const PmeGpu* pmeGpu)
@@ -459,12 +583,6 @@ void pme_gpu_sync_spread_grid(const PmeGpu* pmeGpu)
  */
 static void pme_gpu_init_internal(PmeGpu* pmeGpu, const DeviceContext& deviceContext, const DeviceStream& deviceStream)
 {
-#if GMX_GPU == GMX_GPU_CUDA
-    // Prepare to use the device that this PME task was assigned earlier.
-    // Other entities, such as CUDA timing events, are known to implicitly use the device context.
-    CU_RET_ERR(cudaSetDevice(deviceContext.deviceInfo().id), "Switching to PME CUDA device");
-#endif
-
     /* Allocate the target-specific structures */
     pmeGpu->archSpecific.reset(new PmeGpuSpecific(deviceContext, deviceStream));
     pmeGpu->kernelParams.reset(new PmeGpuKernelParams());
@@ -475,11 +593,15 @@ static void pme_gpu_init_internal(PmeGpu* pmeGpu, const DeviceContext& deviceCon
      * TODO: PME could also try to pick up nice grid sizes (with factors of 2, 3, 5, 7).
      */
 
-#if GMX_GPU == GMX_GPU_CUDA
-    pmeGpu->maxGridWidthX = deviceContext.deviceInfo().prop.maxGridSize[0];
-#elif GMX_GPU == GMX_GPU_OPENCL
-    pmeGpu->maxGridWidthX = INT32_MAX / 2;
+#if GMX_GPU_CUDA
+    pmeGpu->kernelParams->usePipeline       = char(false);
+    pmeGpu->kernelParams->pipelineAtomStart = 0;
+    pmeGpu->kernelParams->pipelineAtomEnd   = 0;
+    pmeGpu->maxGridWidthX                   = deviceContext.deviceInfo().prop.maxGridSize[0];
+#else
+    // Use this path for any non-CUDA GPU acceleration
     // TODO: is there really no global work size limit in OpenCL?
+    pmeGpu->maxGridWidthX = INT32_MAX / 2;
 #endif
 }
 
@@ -488,9 +610,46 @@ void pme_gpu_reinit_3dfft(const PmeGpu* pmeGpu)
     if (pme_gpu_settings(pmeGpu).performGPUFFT)
     {
         pmeGpu->archSpecific->fftSetup.resize(0);
-        for (int i = 0; i < pmeGpu->common->ngrids; i++)
+        const bool         performOutOfPlaceFFT      = pmeGpu->archSpecific->performOutOfPlaceFFT;
+        const bool         allocateGrid              = false;
+        MPI_Comm           comm                      = MPI_COMM_NULL;
+        std::array<int, 1> gridOffsetsInXForEachRank = { 0 };
+        std::array<int, 1> gridOffsetsInYForEachRank = { 0 };
+#if GMX_GPU_CUDA
+        const gmx::FftBackend backend = gmx::FftBackend::Cufft;
+#elif GMX_GPU_OPENCL
+        const gmx::FftBackend backend = gmx::FftBackend::Ocl;
+#elif GMX_GPU_SYCL
+#    if GMX_SYCL_DPCPP && GMX_FFT_MKL
+        const gmx::FftBackend backend = gmx::FftBackend::SyclMkl;
+#    elif GMX_SYCL_HIPSYCL
+        const gmx::FftBackend backend = gmx::FftBackend::SyclRocfft;
+#    else
+        const gmx::FftBackend backend = gmx::FftBackend::Sycl;
+#    endif
+#else
+        GMX_RELEASE_ASSERT(false, "Unknown GPU backend");
+        const gmx::FftBackend backend = gmx::FftBackend::Count;
+#endif
+
+        PmeGpuGridParams& grid = pme_gpu_get_kernel_params_base_ptr(pmeGpu)->grid;
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
         {
-            pmeGpu->archSpecific->fftSetup.push_back(std::make_unique<GpuParallel3dFft>(pmeGpu));
+            pmeGpu->archSpecific->fftSetup.push_back(
+                    std::make_unique<gmx::Gpu3dFft>(backend,
+                                                    allocateGrid,
+                                                    comm,
+                                                    gridOffsetsInXForEachRank,
+                                                    gridOffsetsInYForEachRank,
+                                                    grid.realGridSize[ZZ],
+                                                    performOutOfPlaceFFT,
+                                                    pmeGpu->archSpecific->deviceContext_,
+                                                    pmeGpu->archSpecific->pmeStream_,
+                                                    grid.realGridSize,
+                                                    grid.realGridSizePadded,
+                                                    grid.complexGridSizePadded,
+                                                    &(grid.d_realGrid[gridIndex]),
+                                                    &(grid.d_fourierGrid[gridIndex])));
         }
     }
 }
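
For orientation, each fftSetup entry created above is later driven per grid by pme_gpu_3dfft (further down in this patch). A hedged sketch of that call path, with the timing-event handling elided and an illustrative wrapper name:

// Sketch only: one Gpu3dFft instance per grid, run in either direction on the PME stream.
static void run3dFftForGrid(const PmeGpu* pmeGpu, gmx_fft_direction dir, int gridIndex)
{
    // pme_gpu_3dfft below passes the real timing event; it is omitted in this sketch.
    pmeGpu->archSpecific->fftSetup[gridIndex]->perform3dFft(dir, nullptr);
}
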
@@ -500,26 +659,62 @@ void pme_gpu_destroy_3dfft(const PmeGpu* pmeGpu)
     pmeGpu->archSpecific->fftSetup.resize(0);
 }
 
-void pme_gpu_getEnergyAndVirial(const gmx_pme_t& pme, PmeOutput* output)
+void pme_gpu_getEnergyAndVirial(const gmx_pme_t& pme, const float lambda, PmeOutput* output)
 {
     const PmeGpu* pmeGpu = pme.gpu;
-    for (int j = 0; j < c_virialAndEnergyCount; j++)
-    {
-        GMX_ASSERT(std::isfinite(pmeGpu->staging.h_virialAndEnergy[j]),
-                   "PME GPU produces incorrect energy/virial.");
-    }
-
-    unsigned int j                 = 0;
-    output->coulombVirial_[XX][XX] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[YY][YY] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[ZZ][ZZ] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[XX][YY] = output->coulombVirial_[YY][XX] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[XX][ZZ] = output->coulombVirial_[ZZ][XX] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[YY][ZZ] = output->coulombVirial_[ZZ][YY] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombEnergy_ = 0.5F * pmeGpu->staging.h_virialAndEnergy[j++];
+
+    GMX_ASSERT(lambda == 1.0 || pmeGpu->common->ngrids == 2,
+               "Invalid combination of lambda and number of grids");
+
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        for (int j = 0; j < c_virialAndEnergyCount; j++)
+        {
+            GMX_ASSERT(std::isfinite(pmeGpu->staging.h_virialAndEnergy[gridIndex][j]),
+                       "PME GPU produces incorrect energy/virial.");
+        }
+    }
+    for (int dim1 = 0; dim1 < DIM; dim1++)
+    {
+        for (int dim2 = 0; dim2 < DIM; dim2++)
+        {
+            output->coulombVirial_[dim1][dim2] = 0;
+        }
+    }
+    output->coulombEnergy_ = 0;
+    float scale            = 1.0;
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        if (pmeGpu->common->ngrids == 2)
+        {
+            scale = gridIndex == 0 ? (1.0 - lambda) : lambda;
+        }
+        output->coulombVirial_[XX][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][0];
+        output->coulombVirial_[YY][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][1];
+        output->coulombVirial_[ZZ][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][2];
+        output->coulombVirial_[XX][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][3];
+        output->coulombVirial_[YY][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][3];
+        output->coulombVirial_[XX][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][4];
+        output->coulombVirial_[ZZ][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][4];
+        output->coulombVirial_[YY][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][5];
+        output->coulombVirial_[ZZ][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][5];
+        output->coulombEnergy_ += scale * 0.5F * pmeGpu->staging.h_virialAndEnergy[gridIndex][6];
+    }
+    if (pmeGpu->common->ngrids > 1)
+    {
+        output->coulombDvdl_ = 0.5F
+                               * (pmeGpu->staging.h_virialAndEnergy[FEP_STATE_B][6]
+                                  - pmeGpu->staging.h_virialAndEnergy[FEP_STATE_A][6]);
+    }
 }
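
To spell out the arithmetic in the loop above: each grid's staging buffer holds c_virialAndEnergyCount floats (six virial components followed by the energy term at index 6), and with two grids the contributions are mixed linearly in lambda, which is also where the dV/dlambda term comes from. A self-contained sketch of the energy part only (function name is illustrative):

// Hedged sketch of the lambda mixing performed above (grid 0 = state A, grid 1 = state B).
static float mixCoulombEnergy(const float* rawA, const float* rawB, float lambda, float* dvdl)
{
    const float energyA = 0.5F * rawA[6]; // index 6 holds the raw energy term
    const float energyB = 0.5F * rawB[6];
    *dvdl = energyB - energyA;            // matches 0.5F * (raw_B[6] - raw_A[6]) above
    return (1.0F - lambda) * energyA + lambda * energyB;
}
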
 
 /*! \brief Sets the force-related members in \p output
@@ -536,7 +731,7 @@ static void pme_gpu_getForceOutput(PmeGpu* pmeGpu, PmeOutput* output)
     }
 }
 
-PmeOutput pme_gpu_getOutput(const gmx_pme_t& pme, const bool computeEnergyAndVirial)
+PmeOutput pme_gpu_getOutput(const gmx_pme_t& pme, const bool computeEnergyAndVirial, const real lambdaQ)
 {
     PmeGpu* pmeGpu = pme.gpu;
 
@@ -548,7 +743,7 @@ PmeOutput pme_gpu_getOutput(const gmx_pme_t& pme, const bool computeEnergyAndVir
     {
         if (pme_gpu_settings(pmeGpu).performGPUSolve)
         {
-            pme_gpu_getEnergyAndVirial(pme, &output);
+            pme_gpu_getEnergyAndVirial(pme, lambdaQ, &output);
         }
         else
         {
@@ -590,9 +785,13 @@ void pme_gpu_update_input_box(PmeGpu gmx_unused* pmeGpu, const matrix gmx_unused
 static void pme_gpu_reinit_grids(PmeGpu* pmeGpu)
 {
     auto* kernelParamsPtr = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
+
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
     kernelParamsPtr->grid.ewaldFactor =
             (M_PI * M_PI) / (pmeGpu->common->ewaldcoeff_q * pmeGpu->common->ewaldcoeff_q);
-
     /* The grid size variants */
     for (int i = 0; i < DIM; i++)
     {
@@ -613,14 +812,17 @@ static void pme_gpu_reinit_grids(PmeGpu* pmeGpu)
         kernelParamsPtr->grid.realGridSizePadded[ZZ] =
                 (kernelParamsPtr->grid.realGridSize[ZZ] / 2 + 1) * 2;
     }
-
     /* GPU FFT: n real elements correspond to (n / 2 + 1) complex elements in minor dimension */
     kernelParamsPtr->grid.complexGridSize[ZZ] /= 2;
     kernelParamsPtr->grid.complexGridSize[ZZ]++;
     kernelParamsPtr->grid.complexGridSizePadded[ZZ] = kernelParamsPtr->grid.complexGridSize[ZZ];
 
     pme_gpu_realloc_and_copy_fract_shifts(pmeGpu);
-    pme_gpu_realloc_and_copy_bspline_values(pmeGpu);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        pme_gpu_realloc_and_copy_bspline_values(pmeGpu, gridIndex);
+    }
+
     pme_gpu_realloc_grids(pmeGpu);
     pme_gpu_reinit_3dfft(pmeGpu);
 }
@@ -641,7 +843,7 @@ static void pme_gpu_copy_common_data_from(const gmx_pme_t* pme)
     /* TODO: Consider refactoring the CPU PME code to use the same structure,
      * so that this function becomes 2 lines */
     PmeGpu* pmeGpu               = pme->gpu;
-    pmeGpu->common->ngrids       = pme->ngrids;
+    pmeGpu->common->ngrids       = pme->bFEP_q ? 2 : 1;
     pmeGpu->common->epsilon_r    = pme->epsilon_r;
     pmeGpu->common->ewaldcoeff_q = pme->ewaldcoeff_q;
     pmeGpu->common->nk[XX]       = pme->nkx;
@@ -667,7 +869,7 @@ static void pme_gpu_copy_common_data_from(const gmx_pme_t* pme)
     pmeGpu->common->nn.insert(pmeGpu->common->nn.end(), pme->nnz, pme->nnz + cellCount * pme->nkz);
     pmeGpu->common->runMode       = pme->runMode;
     pmeGpu->common->isRankPmeOnly = !pme->bPPnode;
-    pmeGpu->common->boxScaler     = pme->boxScaler;
+    pmeGpu->common->boxScaler     = pme->boxScaler.get();
 }
 
 /*! \libinternal \brief
@@ -677,7 +879,7 @@ static void pme_gpu_copy_common_data_from(const gmx_pme_t* pme)
  */
 static void pme_gpu_select_best_performing_pme_spreadgather_kernels(PmeGpu* pmeGpu)
 {
-    if (pmeGpu->kernelParams->atoms.nAtoms > c_pmeGpuPerformanceAtomLimit && (GMX_GPU == GMX_GPU_CUDA))
+    if (GMX_GPU_CUDA && pmeGpu->kernelParams->atoms.nAtoms > c_pmeGpuPerformanceAtomLimit)
     {
         pmeGpu->settings.threadsPerAtom     = ThreadsPerAtom::Order;
         pmeGpu->settings.recalculateSplines = true;
@@ -725,14 +927,14 @@ static void pme_gpu_init(gmx_pme_t*           pme,
     pmeGpu->initializedClfftLibrary_ = std::make_unique<gmx::ClfftInitializer>();
 
     pme_gpu_init_internal(pmeGpu, deviceContext, deviceStream);
-    pme_gpu_alloc_energy_virial(pmeGpu);
 
     pme_gpu_copy_common_data_from(pme);
+    pme_gpu_alloc_energy_virial(pmeGpu);
 
     GMX_ASSERT(pmeGpu->common->epsilon_r != 0.0F, "PME GPU: bad electrostatic coefficient");
 
     auto* kernelParamsPtr               = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
-    kernelParamsPtr->constants.elFactor = ONE_4PI_EPS0 / pmeGpu->common->epsilon_r;
+    kernelParamsPtr->constants.elFactor = gmx::c_one4PiEps0 / pmeGpu->common->epsilon_r;
 }
 
 void pme_gpu_get_real_grid_sizes(const PmeGpu* pmeGpu, gmx::IVec* gridSize, gmx::IVec* paddedGridSize)
@@ -804,7 +1006,7 @@ void pme_gpu_destroy(PmeGpu* pmeGpu)
     delete pmeGpu;
 }
 
-void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* charges)
+void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* chargesA, const real* chargesB)
 {
     auto* kernelParamsPtr         = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
     kernelParamsPtr->atoms.nAtoms = nAtoms;
@@ -817,8 +1019,25 @@ void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* charges)
     GMX_RELEASE_ASSERT(false, "Only single precision supported");
-    GMX_UNUSED_VALUE(charges);
+    GMX_UNUSED_VALUE(chargesA);
+    GMX_UNUSED_VALUE(chargesB);
 #else
-    pme_gpu_realloc_and_copy_input_coefficients(pmeGpu, reinterpret_cast<const float*>(charges));
+    int gridIndex = 0;
     /* Could also be checked for haveToRealloc, but the copy always needs to be performed */
+    pme_gpu_realloc_and_copy_input_coefficients(pmeGpu, reinterpret_cast<const float*>(chargesA), gridIndex);
+    gridIndex++;
+    if (chargesB != nullptr)
+    {
+        pme_gpu_realloc_and_copy_input_coefficients(
+                pmeGpu, reinterpret_cast<const float*>(chargesB), gridIndex);
+    }
+    else
+    {
+        /* Fill the second set of coefficients with chargesA as well to be able to avoid
+         * conditionals in the GPU kernels */
+        /* FIXME: This should be avoided by making a separate templated version of the
+         * relevant kernel(s) (probably only pme_gather_kernel). That would require a
+         * reduction of the current number of templated parameters of that kernel. */
+        pme_gpu_realloc_and_copy_input_coefficients(
+                pmeGpu, reinterpret_cast<const float*>(chargesA), gridIndex);
+    }
 #endif
 
     if (haveToRealloc)
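
As a usage note for the widened interface above: callers now pass both charge sets, and a null chargesB simply duplicates chargesA into the second coefficient buffer so the dual-grid kernels need no conditionals. Hypothetical call sites (variable names are illustrative):

// Sketch only; chargesA/chargesB stand for whatever charge arrays the caller holds.
pme_gpu_reinit_atoms(pmeGpu, nAtoms, chargesA, nullptr);  // no charge perturbation
pme_gpu_reinit_atoms(pmeGpu, nAtoms, chargesA, chargesB); // FEP with perturbed charges
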
@@ -835,23 +1054,23 @@ void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* charges)
  * In CUDA result can be nullptr stub, per GpuRegionTimer implementation.
  *
  * \param[in] pmeGpu         The PME GPU data structure.
- * \param[in] PMEStageId     The PME GPU stage gtPME_ index from the enum in src/gromacs/timing/gpu_timing.h
+ * \param[in] pmeStageId     The PME GPU stage gtPME_ index from the enum in src/gromacs/timing/gpu_timing.h
  */
-static CommandEvent* pme_gpu_fetch_timing_event(const PmeGpu* pmeGpu, size_t PMEStageId)
+static CommandEvent* pme_gpu_fetch_timing_event(const PmeGpu* pmeGpu, PmeStage pmeStageId)
 {
     CommandEvent* timingEvent = nullptr;
     if (pme_gpu_timings_enabled(pmeGpu))
     {
-        GMX_ASSERT(PMEStageId < pmeGpu->archSpecific->timingEvents.size(),
-                   "Wrong PME GPU timing event index");
-        timingEvent = pmeGpu->archSpecific->timingEvents[PMEStageId].fetchNextEvent();
+        GMX_ASSERT(pmeStageId < PmeStage::Count, "Wrong PME GPU timing event index");
+        timingEvent = pmeGpu->archSpecific->timingEvents[pmeStageId].fetchNextEvent();
     }
     return timingEvent;
 }
 
-void pme_gpu_3dfft(const PmeGpu* pmeGpu, gmx_fft_direction dir, int grid_index)
+void pme_gpu_3dfft(const PmeGpu* pmeGpu, gmx_fft_direction dir, const int grid_index)
 {
-    int timerId = (dir == GMX_FFT_REAL_TO_COMPLEX) ? gtPME_FFT_R2C : gtPME_FFT_C2R;
+    PmeStage timerId = (dir == GMX_FFT_REAL_TO_COMPLEX) ? PmeStage::FftTransformR2C
+                                                        : PmeStage::FftTransformC2R;
 
     pme_gpu_start_timing(pmeGpu, timerId);
     pmeGpu->archSpecific->fftSetup[grid_index]->perform3dFft(
@@ -881,32 +1100,64 @@ std::pair<int, int> inline pmeGpuCreateGrid(const PmeGpu* pmeGpu, int blockCount
  * \param[in]  pmeGpu                   The PME GPU structure.
  * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSplineAndSpreadKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom threadsPerAtom, bool writeSplinesToGlobal)
+static auto selectSplineAndSpreadKernelPtr(const PmeGpu*  pmeGpu,
+                                           ThreadsPerAtom threadsPerAtom,
+                                           bool           writeSplinesToGlobal,
+                                           const int      numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (writeSplinesToGlobal)
     {
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplines;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesSingle;
+            }
         }
     }
     else
     {
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelSingle;
+            }
         }
     }
 
@@ -919,12 +1170,14 @@ static auto selectSplineAndSpreadKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom
  * \param[in]  pmeGpu                   The PME GPU structure.
  * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSplineKernelPtr(const PmeGpu*  pmeGpu,
-                                  ThreadsPerAtom threadsPerAtom,
-                                  bool gmx_unused writeSplinesToGlobal)
+static auto selectSplineKernelPtr(const PmeGpu*   pmeGpu,
+                                  ThreadsPerAtom  threadsPerAtom,
+                                  bool gmx_unused writeSplinesToGlobal,
+                                  const int       numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     GMX_ASSERT(
@@ -933,11 +1186,25 @@ static auto selectSplineKernelPtr(const PmeGpu*  pmeGpu,
 
     if (threadsPerAtom == ThreadsPerAtom::Order)
     {
-        kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4;
+        if (numGrids == 2)
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4Dual;
+        }
+        else
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4Single;
+        }
     }
     else
     {
-        kernelPtr = pmeGpu->programHandle_->impl_->splineKernel;
+        if (numGrids == 2)
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelDual;
+        }
+        else
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelSingle;
+        }
     }
     return kernelPtr;
 }
@@ -948,21 +1215,39 @@ static auto selectSplineKernelPtr(const PmeGpu*  pmeGpu,
  * \param[in]  pmeGpu                   The PME GPU structure.
  * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSpreadKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom threadsPerAtom, bool writeSplinesToGlobal)
+static auto selectSpreadKernelPtr(const PmeGpu*  pmeGpu,
+                                  ThreadsPerAtom threadsPerAtom,
+                                  bool           writeSplinesToGlobal,
+                                  const int      numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (writeSplinesToGlobal)
     {
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->spreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelSingle;
+            }
         }
     }
     else
@@ -971,26 +1256,46 @@ static auto selectSpreadKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom threadsPe
            using the spline and spread Kernel */
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelSingle;
+            }
         }
     }
     return kernelPtr;
 }
 
-void pme_gpu_spread(const PmeGpu*         pmeGpu,
-                    GpuEventSynchronizer* xReadyOnDevice,
-                    int gmx_unused gridIndex,
-                    real*          h_grid,
-                    bool           computeSplines,
-                    bool           spreadCharges)
+void pme_gpu_spread(const PmeGpu*                  pmeGpu,
+                    GpuEventSynchronizer*          xReadyOnDevice,
+                    real**                         h_grids,
+                    bool                           computeSplines,
+                    bool                           spreadCharges,
+                    const real                     lambda,
+                    const bool                     useGpuDirectComm,
+                    gmx::PmeCoordinateReceiverGpu* pmeCoordinateReceiverGpu)
 {
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
     GMX_ASSERT(computeSplines || spreadCharges,
                "PME spline/spread kernel has invalid input (nothing to do)");
-    const auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
     GMX_ASSERT(kernelParamsPtr->atoms.nAtoms > 0, "No atom data in PME GPU spread");
 
     const size_t blockSize = pmeGpu->programHandle_->impl_->spreadWorkGroupSize;
@@ -1001,11 +1306,12 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     const int  threadsPerAtom =
             (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? order : order * order);
     const bool recalculateSplines = pmeGpu->settings.recalculateSplines;
-#if GMX_GPU == GMX_GPU_OPENCL
-    GMX_ASSERT(pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::OrderSquared,
+
+    GMX_ASSERT(!GMX_GPU_OPENCL || pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::OrderSquared,
                "Only 16 threads per atom supported in OpenCL");
-    GMX_ASSERT(!recalculateSplines, "Recalculating splines not supported in OpenCL");
-#endif
+    GMX_ASSERT(!GMX_GPU_OPENCL || !recalculateSplines,
+               "Recalculating splines not supported in OpenCL");
+
     const int atomsPerBlock = blockSize / threadsPerAtom;
 
     // TODO: pick smaller block size in runtime if needed
@@ -1018,11 +1324,12 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
                "inconsistent atom data padding vs. spreading block size");
 
     // Ensure that coordinates are ready on the device before launching spread;
-    // only needed with CUDA on PP+PME ranks, not on separate PME ranks, in unit tests
-    // nor in OpenCL as these cases use a single stream (hence xReadyOnDevice == nullptr).
-    GMX_ASSERT(xReadyOnDevice != nullptr || (GMX_GPU != GMX_GPU_CUDA)
-                       || pmeGpu->common->isRankPmeOnly || pme_gpu_settings(pmeGpu).copyAllOutputs,
+    // only needed on PP+PME ranks, not on separate PME ranks nor in unit tests,
+    // as these cases use a single stream (hence xReadyOnDevice == nullptr).
+    GMX_ASSERT(xReadyOnDevice != nullptr || pmeGpu->common->isRankPmeOnly
+                       || pme_gpu_settings(pmeGpu).copyAllOutputs,
                "Need a valid coordinate synchronizer on PP+PME ranks with CUDA.");
+
     if (xReadyOnDevice)
     {
         xReadyOnDevice->enqueueWaitEvent(pmeGpu->archSpecific->pmeStream_);
@@ -1031,6 +1338,15 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     const int blockCount = pmeGpu->nAtomsAlloc / atomsPerBlock;
     auto      dimGrid    = pmeGpuCreateGrid(pmeGpu, blockCount);
 
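+    // With a single grid the full charge scale is used; with two FEP grids the state-A
+    // contribution is weighted by (1 - lambda).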
+    if (pmeGpu->common->ngrids == 1)
+    {
+        kernelParamsPtr->current.scale = 1.0;
+    }
+    else
+    {
+        kernelParamsPtr->current.scale = 1.0 - lambda;
+    }
+
     KernelLaunchConfig config;
     config.blockSize[0] = order;
     config.blockSize[1] = (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? 1 : order);
@@ -1038,53 +1354,145 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     config.gridSize[0]  = dimGrid.first;
     config.gridSize[1]  = dimGrid.second;
 
-    int                                timingId;
+    PmeStage                           timingId;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
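+    // Splines are kept in global memory when output is requested or when the gather stage
+    // reuses them instead of recalculating them.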
+    const bool writeGlobalOrSaveSplines          = writeGlobal || (!recalculateSplines);
     if (computeSplines)
     {
         if (spreadCharges)
         {
-            timingId  = gtPME_SPLINEANDSPREAD;
-            kernelPtr = selectSplineAndSpreadKernelPtr(pmeGpu, pmeGpu->settings.threadsPerAtom,
-                                                       writeGlobal || (!recalculateSplines));
+            timingId  = PmeStage::SplineAndSpread;
+            kernelPtr = selectSplineAndSpreadKernelPtr(pmeGpu,
+                                                       pmeGpu->settings.threadsPerAtom,
+                                                       writeGlobalOrSaveSplines,
+                                                       pmeGpu->common->ngrids);
         }
         else
         {
-            timingId  = gtPME_SPLINE;
-            kernelPtr = selectSplineKernelPtr(pmeGpu, pmeGpu->settings.threadsPerAtom,
-                                              writeGlobal || (!recalculateSplines));
+            timingId  = PmeStage::Spline;
+            kernelPtr = selectSplineKernelPtr(pmeGpu,
+                                              pmeGpu->settings.threadsPerAtom,
+                                              writeGlobalOrSaveSplines,
+                                              pmeGpu->common->ngrids);
         }
     }
     else
     {
-        timingId  = gtPME_SPREAD;
-        kernelPtr = selectSpreadKernelPtr(pmeGpu, pmeGpu->settings.threadsPerAtom,
-                                          writeGlobal || (!recalculateSplines));
+        timingId  = PmeStage::Spread;
+        kernelPtr = selectSpreadKernelPtr(
+                pmeGpu, pmeGpu->settings.threadsPerAtom, writeGlobalOrSaveSplines, pmeGpu->common->ngrids);
     }
 
 
     pme_gpu_start_timing(pmeGpu, timingId);
     auto* timingEvent = pme_gpu_fetch_timing_event(pmeGpu, timingId);
+
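+    // Pipelining overlaps the arrival of coordinates from the PP ranks with spread work; it requires the
+    // fused spline-and-spread kernel, GPU-direct communication, more than one sender rank, and no spline
+    // output to global memory.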
+    kernelParamsPtr->usePipeline = char(computeSplines && spreadCharges && useGpuDirectComm
+                                        && (pmeCoordinateReceiverGpu->ppCommNumSenderRanks() > 1)
+                                        && !writeGlobalOrSaveSplines);
+    if (kernelParamsPtr->usePipeline != 0)
+    {
+        int numStagesInPipeline = pmeCoordinateReceiverGpu->ppCommNumSenderRanks();
+
+        for (int i = 0; i < numStagesInPipeline; i++)
+        {
+            int senderRank;
+            if (useGpuDirectComm)
+            {
+                senderRank = pmeCoordinateReceiverGpu->synchronizeOnCoordinatesFromPpRank(
+                        i, *(pmeCoordinateReceiverGpu->ppCommStream(i)));
+            }
+            else
+            {
+                senderRank = i;
+            }
+
+            // Set kernel configuration options specific to this stage of the pipeline
+            std::tie(kernelParamsPtr->pipelineAtomStart, kernelParamsPtr->pipelineAtomEnd) =
+                    pmeCoordinateReceiverGpu->ppCommAtomRange(senderRank);
+            const int blockCount       = static_cast<int>(std::ceil(
+                    static_cast<float>(kernelParamsPtr->pipelineAtomEnd - kernelParamsPtr->pipelineAtomStart)
+                    / atomsPerBlock));
+            auto      dimGrid          = pmeGpuCreateGrid(pmeGpu, blockCount);
+            config.gridSize[0]         = dimGrid.first;
+            config.gridSize[1]         = dimGrid.second;
+            DeviceStream* launchStream = pmeCoordinateReceiverGpu->ppCommStream(senderRank);
+
+
 #if c_canEmbedBuffers
-    const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
+            const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->atoms.d_theta,
-            &kernelParamsPtr->atoms.d_dtheta, &kernelParamsPtr->atoms.d_gridlineIndices,
-            &kernelParamsPtr->grid.d_realGrid, &kernelParamsPtr->grid.d_fractShiftsTable,
-            &kernelParamsPtr->grid.d_gridlineIndicesTable, &kernelParamsPtr->atoms.d_coefficients,
-            &kernelParamsPtr->atoms.d_coordinates);
+            const auto kernelArgs =
+                    prepareGpuKernelArguments(kernelPtr,
+                                              config,
+                                              kernelParamsPtr,
+                                              &kernelParamsPtr->atoms.d_theta,
+                                              &kernelParamsPtr->atoms.d_dtheta,
+                                              &kernelParamsPtr->atoms.d_gridlineIndices,
+                                              &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A],
+                                              &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+                                              &kernelParamsPtr->grid.d_fractShiftsTable,
+                                              &kernelParamsPtr->grid.d_gridlineIndicesTable,
+                                              &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+                                              &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+                                              &kernelParamsPtr->atoms.d_coordinates);
 #endif
 
-    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent,
-                    "PME spline/spread", kernelArgs);
+            launchGpuKernel(kernelPtr, config, *launchStream, timingEvent, "PME spline/spread", kernelArgs);
+        }
+        // Set dependencies for PME stream on all pipeline streams
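+        // so that subsequent PME work (FFT, solve, gather) waits for every per-rank spread kernel to complete.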
+        for (int i = 0; i < pmeCoordinateReceiverGpu->ppCommNumSenderRanks(); i++)
+        {
+            GpuEventSynchronizer event;
+            event.markEvent(*(pmeCoordinateReceiverGpu->ppCommStream(i)));
+            event.enqueueWaitEvent(pmeGpu->archSpecific->pmeStream_);
+        }
+    }
+    else // pipelining is not in use
+    {
+        if (useGpuDirectComm) // Sync all PME-PP communications to PME stream
+        {
+            pmeCoordinateReceiverGpu->synchronizeOnCoordinatesFromAllPpRanks(pmeGpu->archSpecific->pmeStream_);
+        }
+
+#if c_canEmbedBuffers
+        const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
+#else
+        const auto kernelArgs =
+                prepareGpuKernelArguments(kernelPtr,
+                                          config,
+                                          kernelParamsPtr,
+                                          &kernelParamsPtr->atoms.d_theta,
+                                          &kernelParamsPtr->atoms.d_dtheta,
+                                          &kernelParamsPtr->atoms.d_gridlineIndices,
+                                          &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A],
+                                          &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+                                          &kernelParamsPtr->grid.d_fractShiftsTable,
+                                          &kernelParamsPtr->grid.d_gridlineIndicesTable,
+                                          &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+                                          &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+                                          &kernelParamsPtr->atoms.d_coordinates);
+#endif
+
+        launchGpuKernel(kernelPtr,
+                        config,
+                        pmeGpu->archSpecific->pmeStream_,
+                        timingEvent,
+                        "PME spline/spread",
+                        kernelArgs);
+    }
+
     pme_gpu_stop_timing(pmeGpu, timingId);
 
     const auto& settings    = pmeGpu->settings;
     const bool copyBackGrid = spreadCharges && (!settings.performGPUFFT || settings.copyAllOutputs);
     if (copyBackGrid)
     {
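+        // With CPU FFT (or when all outputs are copied), each grid in use is transferred back to the host.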
-        pme_gpu_copy_output_spread_grid(pmeGpu, h_grid);
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+        {
+            float* h_grid = h_grids[gridIndex];
+            pme_gpu_copy_output_spread_grid(pmeGpu, h_grid, gridIndex);
+        }
     }
     const bool copyBackAtomData =
             computeSplines && (!settings.performGPUGather || settings.copyAllOutputs);
@@ -1094,8 +1502,18 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     }
 }
 
-void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrdering, bool computeEnergyAndVirial)
+void pme_gpu_solve(const PmeGpu* pmeGpu,
+                   const int     gridIndex,
+                   t_complex*    h_grid,
+                   GridOrdering  gridOrdering,
+                   bool          computeEnergyAndVirial)
 {
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+    GMX_ASSERT(gridIndex < pmeGpu->common->ngrids,
+               "Invalid combination of gridIndex and number of grids");
+
     const auto& settings               = pmeGpu->settings;
     const bool  copyInputAndOutputGrid = !settings.performGPUFFT || settings.copyAllOutputs;
 
@@ -1104,9 +1522,13 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
     float* h_gridFloat = reinterpret_cast<float*>(h_grid);
     if (copyInputAndOutputGrid)
     {
-        copyToDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid, h_gridFloat, 0,
-                           pmeGpu->archSpecific->complexGridSize, pmeGpu->archSpecific->pmeStream_,
-                           pmeGpu->settings.transferKind, nullptr);
+        copyToDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                           h_gridFloat,
+                           0,
+                           pmeGpu->archSpecific->complexGridSize[gridIndex],
+                           pmeGpu->archSpecific->pmeStream_,
+                           pmeGpu->settings.transferKind,
+                           nullptr);
     }
 
     int majorDim = -1, middleDim = -1, minorDim = -1;
@@ -1144,7 +1566,7 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
     const int warpSize  = pmeGpu->programHandle_->warpSize();
     const int blockSize = (cellsPerBlock + warpSize - 1) / warpSize * warpSize;
 
-    static_assert(GMX_GPU != GMX_GPU_CUDA || c_solveMaxWarpsPerBlock / 2 >= 4,
+    static_assert(!GMX_GPU_CUDA || c_solveMaxWarpsPerBlock / 2 >= 4,
                   "The CUDA solve energy kernels needs at least 4 warps. "
                   "Here we launch at least half of the max warps.");
 
@@ -1156,17 +1578,33 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
                          / gridLinesPerBlock;
     config.gridSize[2] = pmeGpu->kernelParams->grid.complexGridSize[majorDim];
 
-    int                                timingId  = gtPME_SOLVE;
+    PmeStage                           timingId  = PmeStage::Solve;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
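+    // For FEP, gridIndex 0 corresponds to state A and gridIndex 1 to state B; the matching kernel
+    // variant and per-grid buffers are selected below.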
     if (gridOrdering == GridOrdering::YZX)
     {
-        kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernel
-                                           : pmeGpu->programHandle_->impl_->solveYZXKernel;
+        if (gridIndex == 0)
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernelA
+                                               : pmeGpu->programHandle_->impl_->solveYZXKernelA;
+        }
+        else
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernelB
+                                               : pmeGpu->programHandle_->impl_->solveYZXKernelB;
+        }
     }
     else if (gridOrdering == GridOrdering::XYZ)
     {
-        kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernel
-                                           : pmeGpu->programHandle_->impl_->solveXYZKernel;
+        if (gridIndex == 0)
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernelA
+                                               : pmeGpu->programHandle_->impl_->solveXYZKernelA;
+        }
+        else
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernelB
+                                               : pmeGpu->programHandle_->impl_->solveXYZKernelB;
+        }
     }
 
     pme_gpu_start_timing(pmeGpu, timingId);
@@ -1174,26 +1612,37 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
 #if c_canEmbedBuffers
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->grid.d_splineModuli,
-            &kernelParamsPtr->constants.d_virialAndEnergy, &kernelParamsPtr->grid.d_fourierGrid);
+    const auto kernelArgs =
+            prepareGpuKernelArguments(kernelPtr,
+                                      config,
+                                      kernelParamsPtr,
+                                      &kernelParamsPtr->grid.d_splineModuli[gridIndex],
+                                      &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex],
+                                      &kernelParamsPtr->grid.d_fourierGrid[gridIndex]);
 #endif
-    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME solve",
-                    kernelArgs);
+    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME solve", kernelArgs);
     pme_gpu_stop_timing(pmeGpu, timingId);
 
     if (computeEnergyAndVirial)
     {
-        copyFromDeviceBuffer(pmeGpu->staging.h_virialAndEnergy,
-                             &kernelParamsPtr->constants.d_virialAndEnergy, 0, c_virialAndEnergyCount,
-                             pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
+        copyFromDeviceBuffer(pmeGpu->staging.h_virialAndEnergy[gridIndex],
+                             &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex],
+                             0,
+                             c_virialAndEnergyCount,
+                             pmeGpu->archSpecific->pmeStream_,
+                             pmeGpu->settings.transferKind,
+                             nullptr);
     }
 
     if (copyInputAndOutputGrid)
     {
-        copyFromDeviceBuffer(h_gridFloat, &kernelParamsPtr->grid.d_fourierGrid, 0,
-                             pmeGpu->archSpecific->complexGridSize, pmeGpu->archSpecific->pmeStream_,
-                             pmeGpu->settings.transferKind, nullptr);
+        copyFromDeviceBuffer(h_gridFloat,
+                             &kernelParamsPtr->grid.d_fourierGrid[gridIndex],
+                             0,
+                             pmeGpu->archSpecific->complexGridSize[gridIndex],
+                             pmeGpu->archSpecific->pmeStream_,
+                             pmeGpu->settings.transferKind,
+                             nullptr);
     }
 }
 
@@ -1203,10 +1652,14 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
  * \param[in]  pmeGpu                   The PME GPU structure.
  * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
 * \param[in]  readSplinesFromGlobal    bool controlling if we should read spline data from global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-inline auto selectGatherKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom threadsPerAtom, bool readSplinesFromGlobal)
+inline auto selectGatherKernelPtr(const PmeGpu*  pmeGpu,
+                                  ThreadsPerAtom threadsPerAtom,
+                                  bool           readSplinesFromGlobal,
+                                  const int      numGrids)
 
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
@@ -1215,34 +1668,70 @@ inline auto selectGatherKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom threadsPe
     {
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplines;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesSingle;
+            }
         }
     }
     else
     {
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->gatherKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelSingle;
+            }
         }
     }
     return kernelPtr;
 }
 
-
-void pme_gpu_gather(PmeGpu* pmeGpu, const float* h_grid)
+void pme_gpu_gather(PmeGpu* pmeGpu, real** h_grids, const float lambda)
 {
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
     const auto& settings = pmeGpu->settings;
+
     if (!settings.performGPUFFT || settings.copyAllOutputs)
     {
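+        // Upload each host grid (one, or two for FEP) before gathering when the FFT ran on the CPU
+        // or all outputs are copied.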
-        pme_gpu_copy_input_gather_grid(pmeGpu, const_cast<float*>(h_grid));
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+        {
+            float* h_grid = const_cast<float*>(h_grids[gridIndex]);
+            pme_gpu_copy_input_gather_grid(pmeGpu, h_grid, gridIndex);
+        }
     }
 
     if (settings.copyAllOutputs)
@@ -1258,11 +1747,12 @@ void pme_gpu_gather(PmeGpu* pmeGpu, const float* h_grid)
     const int threadsPerAtom =
             (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? order : order * order);
     const bool recalculateSplines = pmeGpu->settings.recalculateSplines;
-#if GMX_GPU == GMX_GPU_OPENCL
-    GMX_ASSERT(pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::OrderSquared,
+
+    GMX_ASSERT(!GMX_GPU_OPENCL || pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::OrderSquared,
                "Only 16 threads per atom supported in OpenCL");
-    GMX_ASSERT(!recalculateSplines, "Recalculating splines not supported in OpenCL");
-#endif
+    GMX_ASSERT(!GMX_GPU_OPENCL || !recalculateSplines,
+               "Recalculating splines not supported in OpenCL");
+
     const int atomsPerBlock = blockSize / threadsPerAtom;
 
     GMX_ASSERT(!(c_pmeAtomDataBlockSize % atomsPerBlock),
@@ -1280,25 +1770,43 @@ void pme_gpu_gather(PmeGpu* pmeGpu, const float* h_grid)
 
     // TODO test different cache configs
 
-    int                                timingId  = gtPME_GATHER;
-    PmeGpuProgramImpl::PmeKernelHandle kernelPtr = selectGatherKernelPtr(
-            pmeGpu, pmeGpu->settings.threadsPerAtom, readGlobal || (!recalculateSplines));
+    PmeStage                           timingId = PmeStage::Gather;
+    PmeGpuProgramImpl::PmeKernelHandle kernelPtr =
+            selectGatherKernelPtr(pmeGpu,
+                                  pmeGpu->settings.threadsPerAtom,
+                                  readGlobal || (!recalculateSplines),
+                                  pmeGpu->common->ngrids);
     // TODO design kernel selection getters and make PmeGpu a friend of PmeGpuProgramImpl
 
     pme_gpu_start_timing(pmeGpu, timingId);
-    auto*       timingEvent     = pme_gpu_fetch_timing_event(pmeGpu, timingId);
-    const auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* timingEvent     = pme_gpu_fetch_timing_event(pmeGpu, timingId);
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
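+    // As in spread, the state-A contribution is weighted by (1 - lambda) when two FEP grids are in use.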
+    if (pmeGpu->common->ngrids == 1)
+    {
+        kernelParamsPtr->current.scale = 1.0;
+    }
+    else
+    {
+        kernelParamsPtr->current.scale = 1.0 - lambda;
+    }
+
 #if c_canEmbedBuffers
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
-    const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->atoms.d_coefficients,
-            &kernelParamsPtr->grid.d_realGrid, &kernelParamsPtr->atoms.d_theta,
-            &kernelParamsPtr->atoms.d_dtheta, &kernelParamsPtr->atoms.d_gridlineIndices,
-            &kernelParamsPtr->atoms.d_forces);
+    const auto kernelArgs =
+            prepareGpuKernelArguments(kernelPtr,
+                                      config,
+                                      kernelParamsPtr,
+                                      &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+                                      &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+                                      &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A],
+                                      &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+                                      &kernelParamsPtr->atoms.d_theta,
+                                      &kernelParamsPtr->atoms.d_dtheta,
+                                      &kernelParamsPtr->atoms.d_gridlineIndices,
+                                      &kernelParamsPtr->atoms.d_forces);
 #endif
-    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME gather",
-                    kernelArgs);
+    launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME gather", kernelArgs);
     pme_gpu_stop_timing(pmeGpu, timingId);
 
     if (pmeGpu->settings.useGpuForceReduction)
@@ -1311,7 +1819,7 @@ void pme_gpu_gather(PmeGpu* pmeGpu, const float* h_grid)
     }
 }
 
-void* pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
+DeviceBuffer<gmx::RVec> pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
 {
     if (pmeGpu && pmeGpu->kernelParams)
     {
@@ -1319,7 +1827,7 @@ void* pme_gpu_get_kernelparam_forces(const PmeGpu* pmeGpu)
     }
     else
     {
-        return nullptr;
+        return DeviceBuffer<gmx::RVec>{};
     }
 }