Two sets of coefficients for Coulomb FEP PME on GPU
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gpu_internal.cpp
index 63b77aa86f81326aee15640c164f02ee64722ab1..635f0f1a62130fafa5b9aed31c2a0844d20d48f5 100644 (file)
@@ -133,26 +133,46 @@ void pme_gpu_synchronize(const PmeGpu* pmeGpu)
 void pme_gpu_alloc_energy_virial(PmeGpu* pmeGpu)
 {
     const size_t energyAndVirialSize = c_virialAndEnergyCount * sizeof(float);
-    allocateDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy, c_virialAndEnergyCount,
-                         pmeGpu->archSpecific->deviceContext_);
-    pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_virialAndEnergy), energyAndVirialSize);
+
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        allocateDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex],
+                             c_virialAndEnergyCount, pmeGpu->archSpecific->deviceContext_);
+        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_virialAndEnergy[gridIndex]), energyAndVirialSize);
+    }
 }
 
 void pme_gpu_free_energy_virial(PmeGpu* pmeGpu)
 {
-    freeDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy);
-    pfree(pmeGpu->staging.h_virialAndEnergy);
-    pmeGpu->staging.h_virialAndEnergy = nullptr;
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        freeDeviceBuffer(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex]);
+        pfree(pmeGpu->staging.h_virialAndEnergy[gridIndex]);
+        pmeGpu->staging.h_virialAndEnergy[gridIndex] = nullptr;
+    }
 }
 
 void pme_gpu_clear_energy_virial(const PmeGpu* pmeGpu)
 {
-    clearDeviceBufferAsync(&pmeGpu->kernelParams->constants.d_virialAndEnergy, 0,
-                           c_virialAndEnergyCount, pmeGpu->archSpecific->pmeStream_);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->constants.d_virialAndEnergy[gridIndex], 0,
+                               c_virialAndEnergyCount, pmeGpu->archSpecific->pmeStream_);
+    }
 }
 
-void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu)
+void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu, const int gridIndex)
 {
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+    GMX_ASSERT(gridIndex < pmeGpu->common->ngrids,
+               "Invalid combination of gridIndex and number of grids");
+
     const int splineValuesOffset[DIM] = { 0, pmeGpu->kernelParams->grid.realGridSize[XX],
                                           pmeGpu->kernelParams->grid.realGridSize[XX]
                                                   + pmeGpu->kernelParams->grid.realGridSize[YY] };
@@ -161,33 +181,36 @@ void pme_gpu_realloc_and_copy_bspline_values(PmeGpu* pmeGpu)
     const int newSplineValuesSize = pmeGpu->kernelParams->grid.realGridSize[XX]
                                     + pmeGpu->kernelParams->grid.realGridSize[YY]
                                     + pmeGpu->kernelParams->grid.realGridSize[ZZ];
-    const bool shouldRealloc = (newSplineValuesSize > pmeGpu->archSpecific->splineValuesSize);
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli, newSplineValuesSize,
-                           &pmeGpu->archSpecific->splineValuesSize,
-                           &pmeGpu->archSpecific->splineValuesSizeAlloc,
+    const bool shouldRealloc = (newSplineValuesSize > pmeGpu->archSpecific->splineValuesSize[gridIndex]);
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex],
+                           newSplineValuesSize, &pmeGpu->archSpecific->splineValuesSize[gridIndex],
+                           &pmeGpu->archSpecific->splineValuesCapacity[gridIndex],
                            pmeGpu->archSpecific->deviceContext_);
     if (shouldRealloc)
     {
         /* Reallocate the host buffer */
-        pfree(pmeGpu->staging.h_splineModuli);
-        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_splineModuli),
+        pfree(pmeGpu->staging.h_splineModuli[gridIndex]);
+        pmalloc(reinterpret_cast<void**>(&pmeGpu->staging.h_splineModuli[gridIndex]),
                 newSplineValuesSize * sizeof(float));
     }
     for (int i = 0; i < DIM; i++)
     {
-        memcpy(pmeGpu->staging.h_splineModuli + splineValuesOffset[i],
+        memcpy(pmeGpu->staging.h_splineModuli[gridIndex] + splineValuesOffset[i],
                pmeGpu->common->bsp_mod[i].data(), pmeGpu->common->bsp_mod[i].size() * sizeof(float));
     }
     /* TODO: pin original buffer instead! */
-    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli, pmeGpu->staging.h_splineModuli,
-                       0, newSplineValuesSize, pmeGpu->archSpecific->pmeStream_,
-                       pmeGpu->settings.transferKind, nullptr);
+    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex],
+                       pmeGpu->staging.h_splineModuli[gridIndex], 0, newSplineValuesSize,
+                       pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
 }
 
 void pme_gpu_free_bspline_values(const PmeGpu* pmeGpu)
 {
-    pfree(pmeGpu->staging.h_splineModuli);
-    freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        pfree(pmeGpu->staging.h_splineModuli[gridIndex]);
+        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_splineModuli[gridIndex]);
+    }
 }
 
 void pme_gpu_realloc_forces(PmeGpu* pmeGpu)
@@ -224,16 +247,18 @@ void pme_gpu_copy_output_forces(PmeGpu* pmeGpu)
                          pmeGpu->settings.transferKind, nullptr);
 }
 
-void pme_gpu_realloc_and_copy_input_coefficients(PmeGpu* pmeGpu, const float* h_coefficients)
+void pme_gpu_realloc_and_copy_input_coefficients(const PmeGpu* pmeGpu,
+                                                 const float*  h_coefficients,
+                                                 const int     gridIndex)
 {
     GMX_ASSERT(h_coefficients, "Bad host-side charge buffer in PME GPU");
     const size_t newCoefficientsSize = pmeGpu->nAtomsAlloc;
     GMX_ASSERT(newCoefficientsSize > 0, "Bad number of atoms in PME GPU");
-    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients, newCoefficientsSize,
-                           &pmeGpu->archSpecific->coefficientsSize,
-                           &pmeGpu->archSpecific->coefficientsSizeAlloc,
+    reallocateDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
+                           newCoefficientsSize, &pmeGpu->archSpecific->coefficientsSize[gridIndex],
+                           &pmeGpu->archSpecific->coefficientsCapacity[gridIndex],
                            pmeGpu->archSpecific->deviceContext_);
-    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients,
+    copyToDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex],
                        const_cast<float*>(h_coefficients), 0, pmeGpu->kernelParams->atoms.nAtoms,
                        pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
 
@@ -241,14 +266,17 @@ void pme_gpu_realloc_and_copy_input_coefficients(PmeGpu* pmeGpu, const float* h_
     const size_t paddingCount = pmeGpu->nAtomsAlloc - paddingIndex;
     if (paddingCount > 0)
     {
-        clearDeviceBufferAsync(&pmeGpu->kernelParams->atoms.d_coefficients, paddingIndex,
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex], paddingIndex,
                                paddingCount, pmeGpu->archSpecific->pmeStream_);
     }
 }
 
 void pme_gpu_free_coefficients(const PmeGpu* pmeGpu)
 {
-    freeDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        freeDeviceBuffer(&pmeGpu->kernelParams->atoms.d_coefficients[gridIndex]);
+    }
 }
 
 void pme_gpu_realloc_spline_data(PmeGpu* pmeGpu)
@@ -304,51 +332,65 @@ void pme_gpu_free_grid_indices(const PmeGpu* pmeGpu)
 
 void pme_gpu_realloc_grids(PmeGpu* pmeGpu)
 {
-    auto*     kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+
     const int newRealGridSize = kernelParamsPtr->grid.realGridSizePadded[XX]
                                 * kernelParamsPtr->grid.realGridSizePadded[YY]
                                 * kernelParamsPtr->grid.realGridSizePadded[ZZ];
     const int newComplexGridSize = kernelParamsPtr->grid.complexGridSizePadded[XX]
                                    * kernelParamsPtr->grid.complexGridSizePadded[YY]
                                    * kernelParamsPtr->grid.complexGridSizePadded[ZZ] * 2;
-    // Multiplied by 2 because we count complex grid size for complex numbers, but all allocations/pointers are float
-    if (pmeGpu->archSpecific->performOutOfPlaceFFT)
-    {
-        /* 2 separate grids */
-        reallocateDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid, newComplexGridSize,
-                               &pmeGpu->archSpecific->complexGridSize,
-                               &pmeGpu->archSpecific->complexGridSizeAlloc,
-                               pmeGpu->archSpecific->deviceContext_);
-        reallocateDeviceBuffer(
-                &kernelParamsPtr->grid.d_realGrid, newRealGridSize, &pmeGpu->archSpecific->realGridSize,
-                &pmeGpu->archSpecific->realGridSizeAlloc, pmeGpu->archSpecific->deviceContext_);
-    }
-    else
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        /* A single buffer so that any grid will fit */
-        const int newGridsSize = std::max(newRealGridSize, newComplexGridSize);
-        reallocateDeviceBuffer(
-                &kernelParamsPtr->grid.d_realGrid, newGridsSize, &pmeGpu->archSpecific->realGridSize,
-                &pmeGpu->archSpecific->realGridSizeAlloc, pmeGpu->archSpecific->deviceContext_);
-        kernelParamsPtr->grid.d_fourierGrid   = kernelParamsPtr->grid.d_realGrid;
-        pmeGpu->archSpecific->complexGridSize = pmeGpu->archSpecific->realGridSize;
-        // the size might get used later for copying the grid
+        // Multiplied by 2 because we count complex grid size for complex numbers, but all allocations/pointers are float
+        if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+        {
+            /* 2 separate grids */
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex], newComplexGridSize,
+                                   &pmeGpu->archSpecific->complexGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->complexGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex], newRealGridSize,
+                                   &pmeGpu->archSpecific->realGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->realGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+        }
+        else
+        {
+            /* A single buffer so that any grid will fit */
+            const int newGridsSize = std::max(newRealGridSize, newComplexGridSize);
+            reallocateDeviceBuffer(&kernelParamsPtr->grid.d_realGrid[gridIndex], newGridsSize,
+                                   &pmeGpu->archSpecific->realGridSize[gridIndex],
+                                   &pmeGpu->archSpecific->realGridCapacity[gridIndex],
+                                   pmeGpu->archSpecific->deviceContext_);
+            kernelParamsPtr->grid.d_fourierGrid[gridIndex] = kernelParamsPtr->grid.d_realGrid[gridIndex];
+            pmeGpu->archSpecific->complexGridSize[gridIndex] =
+                    pmeGpu->archSpecific->realGridSize[gridIndex];
+            // the size might get used later for copying the grid
+        }
     }
 }
 
 void pme_gpu_free_grids(const PmeGpu* pmeGpu)
 {
-    if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
     {
-        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_fourierGrid);
+        if (pmeGpu->archSpecific->performOutOfPlaceFFT)
+        {
+            freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_fourierGrid[gridIndex]);
+        }
+        freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex]);
     }
-    freeDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid);
 }
 
 void pme_gpu_clear_grids(const PmeGpu* pmeGpu)
 {
-    clearDeviceBufferAsync(&pmeGpu->kernelParams->grid.d_realGrid, 0,
-                           pmeGpu->archSpecific->realGridSize, pmeGpu->archSpecific->pmeStream_);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        clearDeviceBufferAsync(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex], 0,
+                               pmeGpu->archSpecific->realGridSize[gridIndex],
+                               pmeGpu->archSpecific->pmeStream_);
+    }
 }
 
 void pme_gpu_realloc_and_copy_fract_shifts(PmeGpu* pmeGpu)
@@ -395,17 +437,18 @@ bool pme_gpu_stream_query(const PmeGpu* pmeGpu)
     return haveStreamTasksCompleted(pmeGpu->archSpecific->pmeStream_);
 }
 
-void pme_gpu_copy_input_gather_grid(const PmeGpu* pmeGpu, float* h_grid)
+void pme_gpu_copy_input_gather_grid(const PmeGpu* pmeGpu, const float* h_grid, const int gridIndex)
 {
-    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid, h_grid, 0, pmeGpu->archSpecific->realGridSize,
+    copyToDeviceBuffer(&pmeGpu->kernelParams->grid.d_realGrid[gridIndex], h_grid, 0,
+                       pmeGpu->archSpecific->realGridSize[gridIndex],
                        pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
 }
 
-void pme_gpu_copy_output_spread_grid(const PmeGpu* pmeGpu, float* h_grid)
+void pme_gpu_copy_output_spread_grid(const PmeGpu* pmeGpu, float* h_grid, const int gridIndex)
 {
-    copyFromDeviceBuffer(h_grid, &pmeGpu->kernelParams->grid.d_realGrid, 0,
-                         pmeGpu->archSpecific->realGridSize, pmeGpu->archSpecific->pmeStream_,
-                         pmeGpu->settings.transferKind, nullptr);
+    copyFromDeviceBuffer(h_grid, &pmeGpu->kernelParams->grid.d_realGrid[gridIndex], 0,
+                         pmeGpu->archSpecific->realGridSize[gridIndex],
+                         pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
     pmeGpu->archSpecific->syncSpreadGridD2H.markEvent(pmeGpu->archSpecific->pmeStream_);
 }
 
@@ -488,9 +531,9 @@ void pme_gpu_reinit_3dfft(const PmeGpu* pmeGpu)
     if (pme_gpu_settings(pmeGpu).performGPUFFT)
     {
         pmeGpu->archSpecific->fftSetup.resize(0);
-        for (int i = 0; i < pmeGpu->common->ngrids; i++)
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
         {
-            pmeGpu->archSpecific->fftSetup.push_back(std::make_unique<GpuParallel3dFft>(pmeGpu));
+            pmeGpu->archSpecific->fftSetup.push_back(std::make_unique<GpuParallel3dFft>(pmeGpu, gridIndex));
         }
     }
 }
@@ -500,26 +543,62 @@ void pme_gpu_destroy_3dfft(const PmeGpu* pmeGpu)
     pmeGpu->archSpecific->fftSetup.resize(0);
 }
 
-void pme_gpu_getEnergyAndVirial(const gmx_pme_t& pme, PmeOutput* output)
+void pme_gpu_getEnergyAndVirial(const gmx_pme_t& pme, const float lambda, PmeOutput* output)
 {
     const PmeGpu* pmeGpu = pme.gpu;
-    for (int j = 0; j < c_virialAndEnergyCount; j++)
-    {
-        GMX_ASSERT(std::isfinite(pmeGpu->staging.h_virialAndEnergy[j]),
-                   "PME GPU produces incorrect energy/virial.");
-    }
-
-    unsigned int j                 = 0;
-    output->coulombVirial_[XX][XX] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[YY][YY] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[ZZ][ZZ] = 0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[XX][YY] = output->coulombVirial_[YY][XX] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[XX][ZZ] = output->coulombVirial_[ZZ][XX] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombVirial_[YY][ZZ] = output->coulombVirial_[ZZ][YY] =
-            0.25F * pmeGpu->staging.h_virialAndEnergy[j++];
-    output->coulombEnergy_ = 0.5F * pmeGpu->staging.h_virialAndEnergy[j++];
+
+    GMX_ASSERT(lambda == 1.0 || pmeGpu->common->ngrids == 2,
+               "Invalid combination of lambda and number of grids");
+
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        for (int j = 0; j < c_virialAndEnergyCount; j++)
+        {
+            GMX_ASSERT(std::isfinite(pmeGpu->staging.h_virialAndEnergy[gridIndex][j]),
+                       "PME GPU produces incorrect energy/virial.");
+        }
+    }
+    for (int dim1 = 0; dim1 < DIM; dim1++)
+    {
+        for (int dim2 = 0; dim2 < DIM; dim2++)
+        {
+            output->coulombVirial_[dim1][dim2] = 0;
+        }
+    }
+    output->coulombEnergy_ = 0;
+    float scale            = 1.0;
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        if (pmeGpu->common->ngrids == 2)
+        {
+            scale = gridIndex == 0 ? (1.0 - lambda) : lambda;
+        }
+        output->coulombVirial_[XX][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][0];
+        output->coulombVirial_[YY][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][1];
+        output->coulombVirial_[ZZ][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][2];
+        output->coulombVirial_[XX][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][3];
+        output->coulombVirial_[YY][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][3];
+        output->coulombVirial_[XX][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][4];
+        output->coulombVirial_[ZZ][XX] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][4];
+        output->coulombVirial_[YY][ZZ] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][5];
+        output->coulombVirial_[ZZ][YY] +=
+                scale * 0.25F * pmeGpu->staging.h_virialAndEnergy[gridIndex][5];
+        output->coulombEnergy_ += scale * 0.5F * pmeGpu->staging.h_virialAndEnergy[gridIndex][6];
+    }
+    if (pmeGpu->common->ngrids > 1)
+    {
+        output->coulombDvdl_ = 0.5F
+                               * (pmeGpu->staging.h_virialAndEnergy[FEP_STATE_B][6]
+                                  - pmeGpu->staging.h_virialAndEnergy[FEP_STATE_A][6]);
+    }
 }
 
 /*! \brief Sets the force-related members in \p output
@@ -536,7 +615,7 @@ static void pme_gpu_getForceOutput(PmeGpu* pmeGpu, PmeOutput* output)
     }
 }
 
-PmeOutput pme_gpu_getOutput(const gmx_pme_t& pme, const bool computeEnergyAndVirial)
+PmeOutput pme_gpu_getOutput(const gmx_pme_t& pme, const bool computeEnergyAndVirial, const real lambdaQ)
 {
     PmeGpu* pmeGpu = pme.gpu;
 
@@ -548,7 +627,7 @@ PmeOutput pme_gpu_getOutput(const gmx_pme_t& pme, const bool computeEnergyAndVir
     {
         if (pme_gpu_settings(pmeGpu).performGPUSolve)
         {
-            pme_gpu_getEnergyAndVirial(pme, &output);
+            pme_gpu_getEnergyAndVirial(pme, lambdaQ, &output);
         }
         else
         {
@@ -590,9 +669,13 @@ void pme_gpu_update_input_box(PmeGpu gmx_unused* pmeGpu, const matrix gmx_unused
 static void pme_gpu_reinit_grids(PmeGpu* pmeGpu)
 {
     auto* kernelParamsPtr = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
+
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
     kernelParamsPtr->grid.ewaldFactor =
             (M_PI * M_PI) / (pmeGpu->common->ewaldcoeff_q * pmeGpu->common->ewaldcoeff_q);
-
     /* The grid size variants */
     for (int i = 0; i < DIM; i++)
     {
@@ -613,14 +696,17 @@ static void pme_gpu_reinit_grids(PmeGpu* pmeGpu)
         kernelParamsPtr->grid.realGridSizePadded[ZZ] =
                 (kernelParamsPtr->grid.realGridSize[ZZ] / 2 + 1) * 2;
     }
-
     /* GPU FFT: n real elements correspond to (n / 2 + 1) complex elements in minor dimension */
     kernelParamsPtr->grid.complexGridSize[ZZ] /= 2;
     kernelParamsPtr->grid.complexGridSize[ZZ]++;
     kernelParamsPtr->grid.complexGridSizePadded[ZZ] = kernelParamsPtr->grid.complexGridSize[ZZ];
 
     pme_gpu_realloc_and_copy_fract_shifts(pmeGpu);
-    pme_gpu_realloc_and_copy_bspline_values(pmeGpu);
+    for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+    {
+        pme_gpu_realloc_and_copy_bspline_values(pmeGpu, gridIndex);
+    }
+
     pme_gpu_realloc_grids(pmeGpu);
     pme_gpu_reinit_3dfft(pmeGpu);
 }
@@ -641,7 +727,7 @@ static void pme_gpu_copy_common_data_from(const gmx_pme_t* pme)
     /* TODO: Consider refactoring the CPU PME code to use the same structure,
      * so that this function becomes 2 lines */
     PmeGpu* pmeGpu               = pme->gpu;
-    pmeGpu->common->ngrids       = pme->ngrids;
+    pmeGpu->common->ngrids       = pme->bFEP_q ? 2 : 1;
     pmeGpu->common->epsilon_r    = pme->epsilon_r;
     pmeGpu->common->ewaldcoeff_q = pme->ewaldcoeff_q;
     pmeGpu->common->nk[XX]       = pme->nkx;
@@ -725,9 +811,9 @@ static void pme_gpu_init(gmx_pme_t*           pme,
     pmeGpu->initializedClfftLibrary_ = std::make_unique<gmx::ClfftInitializer>();
 
     pme_gpu_init_internal(pmeGpu, deviceContext, deviceStream);
-    pme_gpu_alloc_energy_virial(pmeGpu);
 
     pme_gpu_copy_common_data_from(pme);
+    pme_gpu_alloc_energy_virial(pmeGpu);
 
     GMX_ASSERT(pmeGpu->common->epsilon_r != 0.0F, "PME GPU: bad electrostatic coefficient");
 
@@ -804,7 +890,7 @@ void pme_gpu_destroy(PmeGpu* pmeGpu)
     delete pmeGpu;
 }
 
-void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* charges)
+void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* chargesA, const real* chargesB)
 {
     auto* kernelParamsPtr         = pme_gpu_get_kernel_params_base_ptr(pmeGpu);
     kernelParamsPtr->atoms.nAtoms = nAtoms;
@@ -817,8 +903,25 @@ void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* charges)
     GMX_RELEASE_ASSERT(false, "Only single precision supported");
     GMX_UNUSED_VALUE(charges);
 #else
-    pme_gpu_realloc_and_copy_input_coefficients(pmeGpu, reinterpret_cast<const float*>(charges));
+    int gridIndex = 0;
     /* Could also be checked for haveToRealloc, but the copy always needs to be performed */
+    pme_gpu_realloc_and_copy_input_coefficients(pmeGpu, reinterpret_cast<const float*>(chargesA), gridIndex);
+    gridIndex++;
+    if (chargesB != nullptr)
+    {
+        pme_gpu_realloc_and_copy_input_coefficients(
+                pmeGpu, reinterpret_cast<const float*>(chargesB), gridIndex);
+    }
+    else
+    {
+        /* Fill the second set of coefficients with chargesA as well to be able to avoid
+         * conditionals in the GPU kernels */
+        /* FIXME: This should be avoided by making a separate templated version of the
+         * relevant kernel(s) (probably only pme_gather_kernel). That would require a
+         * reduction of the current number of templated parameters of that kernel. */
+        pme_gpu_realloc_and_copy_input_coefficients(
+                pmeGpu, reinterpret_cast<const float*>(chargesA), gridIndex);
+    }
 #endif
 
     if (haveToRealloc)
@@ -849,7 +952,7 @@ static CommandEvent* pme_gpu_fetch_timing_event(const PmeGpu* pmeGpu, size_t PME
     return timingEvent;
 }
 
-void pme_gpu_3dfft(const PmeGpu* pmeGpu, gmx_fft_direction dir, int grid_index)
+void pme_gpu_3dfft(const PmeGpu* pmeGpu, gmx_fft_direction dir, const int grid_index)
 {
     int timerId = (dir == GMX_FFT_REAL_TO_COMPLEX) ? gtPME_FFT_R2C : gtPME_FFT_C2R;
 
@@ -881,32 +984,64 @@ std::pair<int, int> inline pmeGpuCreateGrid(const PmeGpu* pmeGpu, int blockCount
  * \param[in]  pmeGpu                   The PME GPU structure.
  * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSplineAndSpreadKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom threadsPerAtom, bool writeSplinesToGlobal)
+static auto selectSplineAndSpreadKernelPtr(const PmeGpu*  pmeGpu,
+                                           ThreadsPerAtom threadsPerAtom,
+                                           bool           writeSplinesToGlobal,
+                                           const int      numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (writeSplinesToGlobal)
     {
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplines;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelWriteSplinesSingle;
+            }
         }
     }
     else
     {
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelSingle;
+            }
         }
     }
 
@@ -919,12 +1054,14 @@ static auto selectSplineAndSpreadKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom
  * \param[in]  pmeGpu                   The PME GPU structure.
  * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
 static auto selectSplineKernelPtr(const PmeGpu*  pmeGpu,
                                   ThreadsPerAtom threadsPerAtom,
-                                  bool gmx_unused writeSplinesToGlobal)
+                                  bool gmx_unused writeSplinesToGlobal,
+                                  const int       numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     GMX_ASSERT(
@@ -933,11 +1070,25 @@ static auto selectSplineKernelPtr(const PmeGpu*  pmeGpu,
 
     if (threadsPerAtom == ThreadsPerAtom::Order)
     {
-        kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4;
+        if (numGrids == 2)
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4Dual;
+        }
+        else
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelThPerAtom4Single;
+        }
     }
     else
     {
-        kernelPtr = pmeGpu->programHandle_->impl_->splineKernel;
+        if (numGrids == 2)
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelDual;
+        }
+        else
+        {
+            kernelPtr = pmeGpu->programHandle_->impl_->splineKernelSingle;
+        }
     }
     return kernelPtr;
 }
@@ -948,21 +1099,38 @@ static auto selectSplineKernelPtr(const PmeGpu*  pmeGpu,
  * \param[in]  pmeGpu                   The PME GPU structure.
  * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  writeSplinesToGlobal     bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-static auto selectSpreadKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom threadsPerAtom, bool writeSplinesToGlobal)
+static auto selectSpreadKernelPtr(const PmeGpu*  pmeGpu,
+                                  ThreadsPerAtom threadsPerAtom,
+                                  bool           writeSplinesToGlobal,
+                                  const int      numGrids)
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (writeSplinesToGlobal)
     {
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4Dual;
+            }
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->spreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->spreadKernelSingle;
+            }
         }
     }
     else
@@ -971,11 +1139,25 @@ static auto selectSpreadKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom threadsPe
            using the spline and spread Kernel */
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->splineAndSpreadKernelSingle;
+            }
         }
     }
     return kernelPtr;
@@ -983,14 +1165,18 @@ static auto selectSpreadKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom threadsPe
 
 void pme_gpu_spread(const PmeGpu*         pmeGpu,
                     GpuEventSynchronizer* xReadyOnDevice,
-                    int gmx_unused gridIndex,
-                    real*          h_grid,
-                    bool           computeSplines,
-                    bool           spreadCharges)
+                    real**                h_grids,
+                    bool                  computeSplines,
+                    bool                  spreadCharges,
+                    const real            lambda)
 {
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
     GMX_ASSERT(computeSplines || spreadCharges,
                "PME spline/spread kernel has invalid input (nothing to do)");
-    const auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
     GMX_ASSERT(kernelParamsPtr->atoms.nAtoms > 0, "No atom data in PME GPU spread");
 
     const size_t blockSize = pmeGpu->programHandle_->impl_->spreadWorkGroupSize;
@@ -1031,6 +1217,15 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     const int blockCount = pmeGpu->nAtomsAlloc / atomsPerBlock;
     auto      dimGrid    = pmeGpuCreateGrid(pmeGpu, blockCount);
 
+    if (pmeGpu->common->ngrids == 1)
+    {
+        kernelParamsPtr->current.scale = 1.0;
+    }
+    else
+    {
+        kernelParamsPtr->current.scale = 1.0 - lambda;
+    }
+
     KernelLaunchConfig config;
     config.blockSize[0] = order;
     config.blockSize[1] = (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? 1 : order);
@@ -1046,20 +1241,22 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
         {
             timingId  = gtPME_SPLINEANDSPREAD;
             kernelPtr = selectSplineAndSpreadKernelPtr(pmeGpu, pmeGpu->settings.threadsPerAtom,
-                                                       writeGlobal || (!recalculateSplines));
+                                                       writeGlobal || (!recalculateSplines),
+                                                       pmeGpu->common->ngrids);
         }
         else
         {
             timingId  = gtPME_SPLINE;
             kernelPtr = selectSplineKernelPtr(pmeGpu, pmeGpu->settings.threadsPerAtom,
-                                              writeGlobal || (!recalculateSplines));
+                                              writeGlobal || (!recalculateSplines),
+                                              pmeGpu->common->ngrids);
         }
     }
     else
     {
         timingId  = gtPME_SPREAD;
         kernelPtr = selectSpreadKernelPtr(pmeGpu, pmeGpu->settings.threadsPerAtom,
-                                          writeGlobal || (!recalculateSplines));
+                                          writeGlobal || (!recalculateSplines), pmeGpu->common->ngrids);
     }
 
 
@@ -1071,9 +1268,10 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     const auto kernelArgs = prepareGpuKernelArguments(
             kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->atoms.d_theta,
             &kernelParamsPtr->atoms.d_dtheta, &kernelParamsPtr->atoms.d_gridlineIndices,
-            &kernelParamsPtr->grid.d_realGrid, &kernelParamsPtr->grid.d_fractShiftsTable,
-            &kernelParamsPtr->grid.d_gridlineIndicesTable, &kernelParamsPtr->atoms.d_coefficients,
-            &kernelParamsPtr->atoms.d_coordinates);
+            &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A], &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+            &kernelParamsPtr->grid.d_fractShiftsTable, &kernelParamsPtr->grid.d_gridlineIndicesTable,
+            &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+            &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B], &kernelParamsPtr->atoms.d_coordinates);
 #endif
 
     launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent,
@@ -1084,7 +1282,11 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     const bool copyBackGrid = spreadCharges && (!settings.performGPUFFT || settings.copyAllOutputs);
     if (copyBackGrid)
     {
-        pme_gpu_copy_output_spread_grid(pmeGpu, h_grid);
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+        {
+            float* h_grid = h_grids[gridIndex];
+            pme_gpu_copy_output_spread_grid(pmeGpu, h_grid, gridIndex);
+        }
     }
     const bool copyBackAtomData =
             computeSplines && (!settings.performGPUGather || settings.copyAllOutputs);
@@ -1094,8 +1296,18 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     }
 }
 
-void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrdering, bool computeEnergyAndVirial)
+void pme_gpu_solve(const PmeGpu* pmeGpu,
+                   const int     gridIndex,
+                   t_complex*    h_grid,
+                   GridOrdering  gridOrdering,
+                   bool          computeEnergyAndVirial)
 {
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+    GMX_ASSERT(gridIndex < pmeGpu->common->ngrids,
+               "Invalid combination of gridIndex and number of grids");
+
     const auto& settings               = pmeGpu->settings;
     const bool  copyInputAndOutputGrid = !settings.performGPUFFT || settings.copyAllOutputs;
 
@@ -1104,9 +1316,9 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
     float* h_gridFloat = reinterpret_cast<float*>(h_grid);
     if (copyInputAndOutputGrid)
     {
-        copyToDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid, h_gridFloat, 0,
-                           pmeGpu->archSpecific->complexGridSize, pmeGpu->archSpecific->pmeStream_,
-                           pmeGpu->settings.transferKind, nullptr);
+        copyToDeviceBuffer(&kernelParamsPtr->grid.d_fourierGrid[gridIndex], h_gridFloat, 0,
+                           pmeGpu->archSpecific->complexGridSize[gridIndex],
+                           pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
     }
 
     int majorDim = -1, middleDim = -1, minorDim = -1;
@@ -1160,13 +1372,29 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (gridOrdering == GridOrdering::YZX)
     {
-        kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernel
-                                           : pmeGpu->programHandle_->impl_->solveYZXKernel;
+        if (gridIndex == 0)
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernelA
+                                               : pmeGpu->programHandle_->impl_->solveYZXKernelA;
+        }
+        else
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveYZXEnergyKernelB
+                                               : pmeGpu->programHandle_->impl_->solveYZXKernelB;
+        }
     }
     else if (gridOrdering == GridOrdering::XYZ)
     {
-        kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernel
-                                           : pmeGpu->programHandle_->impl_->solveXYZKernel;
+        if (gridIndex == 0)
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernelA
+                                               : pmeGpu->programHandle_->impl_->solveXYZKernelA;
+        }
+        else
+        {
+            kernelPtr = computeEnergyAndVirial ? pmeGpu->programHandle_->impl_->solveXYZEnergyKernelB
+                                               : pmeGpu->programHandle_->impl_->solveXYZKernelB;
+        }
     }
 
     pme_gpu_start_timing(pmeGpu, timingId);
@@ -1175,8 +1403,9 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
     const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->grid.d_splineModuli,
-            &kernelParamsPtr->constants.d_virialAndEnergy, &kernelParamsPtr->grid.d_fourierGrid);
+            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->grid.d_splineModuli[gridIndex],
+            &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex],
+            &kernelParamsPtr->grid.d_fourierGrid[gridIndex]);
 #endif
     launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME solve",
                     kernelArgs);
@@ -1184,16 +1413,17 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
 
     if (computeEnergyAndVirial)
     {
-        copyFromDeviceBuffer(pmeGpu->staging.h_virialAndEnergy,
-                             &kernelParamsPtr->constants.d_virialAndEnergy, 0, c_virialAndEnergyCount,
-                             pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
+        copyFromDeviceBuffer(pmeGpu->staging.h_virialAndEnergy[gridIndex],
+                             &kernelParamsPtr->constants.d_virialAndEnergy[gridIndex], 0,
+                             c_virialAndEnergyCount, pmeGpu->archSpecific->pmeStream_,
+                             pmeGpu->settings.transferKind, nullptr);
     }
 
     if (copyInputAndOutputGrid)
     {
-        copyFromDeviceBuffer(h_gridFloat, &kernelParamsPtr->grid.d_fourierGrid, 0,
-                             pmeGpu->archSpecific->complexGridSize, pmeGpu->archSpecific->pmeStream_,
-                             pmeGpu->settings.transferKind, nullptr);
+        copyFromDeviceBuffer(h_gridFloat, &kernelParamsPtr->grid.d_fourierGrid[gridIndex], 0,
+                             pmeGpu->archSpecific->complexGridSize[gridIndex],
+                             pmeGpu->archSpecific->pmeStream_, pmeGpu->settings.transferKind, nullptr);
     }
 }
 
@@ -1203,10 +1433,14 @@ void pme_gpu_solve(const PmeGpu* pmeGpu, t_complex* h_grid, GridOrdering gridOrd
  * \param[in]  pmeGpu                   The PME GPU structure.
  * \param[in]  threadsPerAtom           Controls whether we should use order or order*order threads per atom
  * \param[in]  readSplinesFromGlobal    bool controlling if we should write spline data to global memory
+ * \param[in]  numGrids                 Number of grids to use. numGrids == 2 if Coulomb is perturbed.
  *
  * \return Pointer to CUDA kernel
  */
-inline auto selectGatherKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom threadsPerAtom, bool readSplinesFromGlobal)
+inline auto selectGatherKernelPtr(const PmeGpu*  pmeGpu,
+                                  ThreadsPerAtom threadsPerAtom,
+                                  bool           readSplinesFromGlobal,
+                                  const int      numGrids)
 
 {
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
@@ -1215,34 +1449,70 @@ inline auto selectGatherKernelPtr(const PmeGpu* pmeGpu, ThreadsPerAtom threadsPe
     {
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplines;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelReadSplinesSingle;
+            }
         }
     }
     else
     {
         if (threadsPerAtom == ThreadsPerAtom::Order)
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4Dual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelThPerAtom4Single;
+            }
         }
         else
         {
-            kernelPtr = pmeGpu->programHandle_->impl_->gatherKernel;
+            if (numGrids == 2)
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelDual;
+            }
+            else
+            {
+                kernelPtr = pmeGpu->programHandle_->impl_->gatherKernelSingle;
+            }
         }
     }
     return kernelPtr;
 }
 
-
-void pme_gpu_gather(PmeGpu* pmeGpu, const float* h_grid)
+void pme_gpu_gather(PmeGpu* pmeGpu, real** h_grids, const float lambda)
 {
+    GMX_ASSERT(
+            pmeGpu->common->ngrids == 1 || pmeGpu->common->ngrids == 2,
+            "Only one (normal Coulomb PME) or two (FEP coulomb PME) PME grids can be used on GPU");
+
     const auto& settings = pmeGpu->settings;
+
     if (!settings.performGPUFFT || settings.copyAllOutputs)
     {
-        pme_gpu_copy_input_gather_grid(pmeGpu, const_cast<float*>(h_grid));
+        for (int gridIndex = 0; gridIndex < pmeGpu->common->ngrids; gridIndex++)
+        {
+            float* h_grid = const_cast<float*>(h_grids[gridIndex]);
+            pme_gpu_copy_input_gather_grid(pmeGpu, h_grid, gridIndex);
+        }
     }
 
     if (settings.copyAllOutputs)
@@ -1280,22 +1550,33 @@ void pme_gpu_gather(PmeGpu* pmeGpu, const float* h_grid)
 
     // TODO test different cache configs
 
-    int                                timingId  = gtPME_GATHER;
-    PmeGpuProgramImpl::PmeKernelHandle kernelPtr = selectGatherKernelPtr(
-            pmeGpu, pmeGpu->settings.threadsPerAtom, readGlobal || (!recalculateSplines));
+    int                                timingId = gtPME_GATHER;
+    PmeGpuProgramImpl::PmeKernelHandle kernelPtr =
+            selectGatherKernelPtr(pmeGpu, pmeGpu->settings.threadsPerAtom,
+                                  readGlobal || (!recalculateSplines), pmeGpu->common->ngrids);
     // TODO design kernel selection getters and make PmeGpu a friend of PmeGpuProgramImpl
 
     pme_gpu_start_timing(pmeGpu, timingId);
-    auto*       timingEvent     = pme_gpu_fetch_timing_event(pmeGpu, timingId);
-    const auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+    auto* timingEvent     = pme_gpu_fetch_timing_event(pmeGpu, timingId);
+    auto* kernelParamsPtr = pmeGpu->kernelParams.get();
+    if (pmeGpu->common->ngrids == 1)
+    {
+        kernelParamsPtr->current.scale = 1.0;
+    }
+    else
+    {
+        kernelParamsPtr->current.scale = 1.0 - lambda;
+    }
+
 #if c_canEmbedBuffers
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
 #else
     const auto kernelArgs = prepareGpuKernelArguments(
-            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->atoms.d_coefficients,
-            &kernelParamsPtr->grid.d_realGrid, &kernelParamsPtr->atoms.d_theta,
-            &kernelParamsPtr->atoms.d_dtheta, &kernelParamsPtr->atoms.d_gridlineIndices,
-            &kernelParamsPtr->atoms.d_forces);
+            kernelPtr, config, kernelParamsPtr, &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_A],
+            &kernelParamsPtr->atoms.d_coefficients[FEP_STATE_B],
+            &kernelParamsPtr->grid.d_realGrid[FEP_STATE_A], &kernelParamsPtr->grid.d_realGrid[FEP_STATE_B],
+            &kernelParamsPtr->atoms.d_theta, &kernelParamsPtr->atoms.d_dtheta,
+            &kernelParamsPtr->atoms.d_gridlineIndices, &kernelParamsPtr->atoms.d_forces);
 #endif
     launchGpuKernel(kernelPtr, config, pmeGpu->archSpecific->pmeStream_, timingEvent, "PME gather",
                     kernelArgs);