Disable PME Mixed mode with FEP
[alexxy/gromacs.git] / src / gromacs / ewald / pme.cpp
index d6cd44b346b0708f15cdc3588aebf3fbf3d46d14..b0b59e6eedbaf7fe8f37134805eab50a5280b711 100644 (file)
@@ -3,8 +3,8 @@
  *
  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
  * Copyright (c) 2001-2004, The GROMACS development team.
- * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
- * Copyright (c) 2018,2019,2020, by the GROMACS development team, led by
+ * Copyright (c) 2013,2014,2015,2016,2017 The GROMACS development team.
+ * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
@@ -99,6 +99,7 @@
 #include "gromacs/mdtypes/forcerec.h"
 #include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/md_enums.h"
+#include "gromacs/mdtypes/simulation_workload.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/timing/cyclecounter.h"
 #include "gromacs/timing/wallcycle.h"
@@ -158,10 +159,14 @@ bool pme_gpu_supports_build(std::string* error)
     {
         errorReasons.emplace_back("a double-precision build");
     }
-    if (GMX_GPU == GMX_GPU_NONE)
+    if (!GMX_GPU)
     {
         errorReasons.emplace_back("a non-GPU build");
     }
+    if (GMX_GPU_SYCL)
+    {
+        errorReasons.emplace_back("SYCL build"); // SYCL-TODO
+    }
     return addMessageIfNotSupported(errorReasons, error);
 }
 
@@ -169,7 +174,7 @@ bool pme_gpu_supports_hardware(const gmx_hw_info_t gmx_unused& hwinfo, std::stri
 {
     std::list<std::string> errorReasons;
 
-    if (GMX_GPU == GMX_GPU_OPENCL)
+    if (GMX_GPU_OPENCL)
     {
 #ifdef __APPLE__
         errorReasons.emplace_back("Apple OS X operating system");
@@ -178,7 +183,7 @@ bool pme_gpu_supports_hardware(const gmx_hw_info_t gmx_unused& hwinfo, std::stri
     return addMessageIfNotSupported(errorReasons, error);
 }
 
-bool pme_gpu_supports_input(const t_inputrec& ir, const gmx_mtop_t& mtop, std::string* error)
+bool pme_gpu_supports_input(const t_inputrec& ir, std::string* error)
 {
     std::list<std::string> errorReasons;
     if (!EEL_PME(ir.coulombtype))
@@ -189,21 +194,25 @@ bool pme_gpu_supports_input(const t_inputrec& ir, const gmx_mtop_t& mtop, std::s
     {
         errorReasons.emplace_back("interpolation orders other than 4");
     }
-    if (ir.efep != efepNO)
-    {
-        if (gmx_mtop_has_perturbed_charges(mtop))
-        {
-            errorReasons.emplace_back(
-                    "free energy calculations with perturbed charges (multiple grids)");
-        }
-    }
     if (EVDW_PME(ir.vdwtype))
     {
         errorReasons.emplace_back("Lennard-Jones PME");
     }
     if (!EI_DYNAMICS(ir.eI))
     {
-        errorReasons.emplace_back("not a dynamical integrator");
+        errorReasons.emplace_back(
+                "Cannot compute PME interactions on a GPU, because PME GPU requires a dynamical "
+                "integrator (md, sd, etc).");
+    }
+    return addMessageIfNotSupported(errorReasons, error);
+}
+
+bool pme_gpu_mixed_mode_supports_input(const t_inputrec& ir, std::string* error)
+{ // Collects reasons why the input in ir cannot be run in PME GPU mixed mode.
+    std::list<std::string> errorReasons; // one entry per unsupported setting found
+    if (ir.efep != efepNO) // any efep setting other than efepNO is reported as unsupported
+    {
+        errorReasons.emplace_back("Free Energy Perturbation (in PME GPU mixed mode)");
+    }
     return addMessageIfNotSupported(errorReasons, error);
 }
@@ -228,10 +237,6 @@ static bool pme_gpu_check_restrictions(const gmx_pme_t* pme, std::string* error)
     {
         errorReasons.emplace_back("interpolation orders other than 4");
     }
-    if (pme->bFEP)
-    {
-        errorReasons.emplace_back("free energy calculations (multiple grids)");
-    }
     if (pme->doLJ)
     {
         errorReasons.emplace_back("Lennard-Jones PME");
@@ -240,11 +245,14 @@ static bool pme_gpu_check_restrictions(const gmx_pme_t* pme, std::string* error)
     {
         errorReasons.emplace_back("double precision");
     }
-    if (GMX_GPU == GMX_GPU_NONE)
+    if (!GMX_GPU)
     {
         errorReasons.emplace_back("non-GPU build of GROMACS");
     }
-
+    if (GMX_GPU_SYCL)
+    {
+        errorReasons.emplace_back("SYCL build of GROMACS"); // SYCL-TODO
+    }
     return addMessageIfNotSupported(errorReasons, error);
 }
 
@@ -559,20 +567,21 @@ static int div_round_up(int enumerator, int denominator)
     return (enumerator + denominator - 1) / denominator;
 }
 
-gmx_pme_t* gmx_pme_init(const t_commrec*         cr,
-                        const NumPmeDomains&     numPmeDomains,
-                        const t_inputrec*        ir,
-                        gmx_bool                 bFreeEnergy_q,
-                        gmx_bool                 bFreeEnergy_lj,
-                        gmx_bool                 bReproducible,
-                        real                     ewaldcoeff_q,
-                        real                     ewaldcoeff_lj,
-                        int                      nthread,
-                        PmeRunMode               runMode,
-                        PmeGpu*                  pmeGpu,
-                        const gmx_device_info_t* gpuInfo,
-                        const PmeGpuProgram*     pmeGpuProgram,
-                        const gmx::MDLogger& /*mdlog*/)
+gmx_pme_t* gmx_pme_init(const t_commrec*     cr,
+                        const NumPmeDomains& numPmeDomains,
+                        const t_inputrec*    ir,
+                        gmx_bool             bFreeEnergy_q,
+                        gmx_bool             bFreeEnergy_lj,
+                        gmx_bool             bReproducible,
+                        real                 ewaldcoeff_q,
+                        real                 ewaldcoeff_lj,
+                        int                  nthread,
+                        PmeRunMode           runMode,
+                        PmeGpu*              pmeGpu,
+                        const DeviceContext* deviceContext,
+                        const DeviceStream*  deviceStream,
+                        const PmeGpuProgram* pmeGpuProgram,
+                        const gmx::MDLogger& mdlog)
 {
     int  use_threads, sum_use_threads, i;
     ivec ndata;
@@ -741,17 +750,19 @@ gmx_pme_t* gmx_pme_init(const t_commrec*         cr,
         imbal = estimate_pme_load_imbalance(pme.get());
         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
         {
-            fprintf(stderr,
-                    "\n"
-                    "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
-                    "      For optimal PME load balancing\n"
-                    "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_ranks_x "
-                    "(%d)\n"
-                    "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_ranks_y "
-                    "(%d)\n"
-                    "\n",
-                    gmx::roundToInt((imbal - 1) * 100), pme->nkx, pme->nky, pme->nnodes_major,
-                    pme->nky, pme->nkz, pme->nnodes_minor);
+            GMX_LOG(mdlog.warning)
+                    .asParagraph()
+                    .appendTextFormatted(
+                            "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
+                            "      For optimal PME load balancing\n"
+                            "      PME grid_x (%d) and grid_y (%d) should be divisible by "
+                            "#PME_ranks_x "
+                            "(%d)\n"
+                            "      and PME grid_y (%d) and grid_z (%d) should be divisible by "
+                            "#PME_ranks_y "
+                            "(%d)",
+                            gmx::roundToInt((imbal - 1) * 100), pme->nkx, pme->nky,
+                            pme->nnodes_major, pme->nky, pme->nkz, pme->nnodes_minor);
         }
     }
 
@@ -882,8 +893,13 @@ gmx_pme_t* gmx_pme_init(const t_commrec*         cr,
         {
             GMX_THROW(gmx::NotImplementedError(errorString));
         }
+        pme_gpu_reinit(pme.get(), deviceContext, deviceStream, pmeGpuProgram);
+    }
+    else
+    {
+        GMX_ASSERT(pme->gpu == nullptr, "Should not have PME GPU object when PME is on a CPU.");
     }
-    pme_gpu_reinit(pme.get(), gpuInfo, pmeGpuProgram);
+
 
     pme_init_all_work(&pme->solve_work, pme->nthread, pme->nkx);
 
@@ -916,21 +932,21 @@ void gmx_pme_reinit(struct gmx_pme_t** pmedata,
 
     try
     {
+        // This is reinit. Any logging should have been done at first init.
+        // Here we should avoid writing notes for settings the user did not
+        // set directly.
         const gmx::MDLogger dummyLogger;
-        // This is reinit which is currently only changing grid size/coefficients,
-        // so we don't expect the actual logging.
-        // TODO: when PME is an object, it should take reference to mdlog on construction and save it.
         GMX_ASSERT(pmedata, "Invalid PME pointer");
         NumPmeDomains numPmeDomains = { pme_src->nnodes_major, pme_src->nnodes_minor };
         *pmedata = gmx_pme_init(cr, numPmeDomains, &irc, pme_src->bFEP_q, pme_src->bFEP_lj, FALSE,
                                 ewaldcoeff_q, ewaldcoeff_lj, pme_src->nthread, pme_src->runMode,
-                                pme_src->gpu, nullptr, nullptr, dummyLogger);
+                                pme_src->gpu, nullptr, nullptr, nullptr, dummyLogger);
         /* When running PME on the CPU not using domain decomposition,
          * the atom data is allocated once only in gmx_pme_(re)init().
          */
         if (!pme_src->gpu && pme_src->nnodes == 1)
         {
-            gmx_pme_reinit_atoms(*pmedata, pme_src->atc[0].numAtoms(), nullptr);
+            gmx_pme_reinit_atoms(*pmedata, pme_src->atc[0].numAtoms(), nullptr, nullptr);
         }
         // TODO this is mostly passing around current values
     }
@@ -1016,7 +1032,7 @@ int gmx_pme_do(struct gmx_pme_t*              pme,
                real                           lambda_lj,
                real*                          dvdlambda_q,
                real*                          dvdlambda_lj,
-               int                            flags)
+               const gmx::StepWorkload&       stepWork)
 {
     GMX_ASSERT(pme->runMode == PmeRunMode::CPU,
                "gmx_pme_do should not be called on the GPU PME run.");
@@ -1036,18 +1052,17 @@ int gmx_pme_do(struct gmx_pme_t*              pme,
     gmx_bool             bFirst, bDoSplines;
     int                  fep_state;
     int                  fep_states_lj = pme->bFEP_lj ? 2 : 1;
-    const gmx_bool       bCalcEnerVir  = (flags & GMX_PME_CALC_ENER_VIR) != 0;
-    const gmx_bool       bBackFFT      = (flags & (GMX_PME_CALC_F | GMX_PME_CALC_POT)) != 0;
-    const gmx_bool       bCalcF        = (flags & GMX_PME_CALC_F) != 0;
+    // There's no support for computing energy without virial, or vice versa
+    const bool computeEnergyAndVirial = (stepWork.computeEnergy || stepWork.computeVirial);
 
-    /* We could be passing lambda!=1 while no q or LJ is actually perturbed */
+    /* We could be passing lambda!=0 while no q or LJ is actually perturbed */
     if (!pme->bFEP_q)
     {
-        lambda_q = 1;
+        lambda_q = 0;
     }
     if (!pme->bFEP_lj)
     {
-        lambda_lj = 1;
+        lambda_lj = 0;
     }
 
     assert(pme->nnodes > 0);
@@ -1153,41 +1168,37 @@ int gmx_pme_do(struct gmx_pme_t*              pme,
             fprintf(debug, "Rank= %6d, pme local particles=%6d\n", cr->nodeid, atc.numAtoms());
         }
 
-        if (flags & GMX_PME_SPREAD)
+        wallcycle_start(wcycle, ewcPME_SPREAD);
+
+        /* Spread the coefficients on a grid */
+        spread_on_grid(pme, &atc, pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
+
+        if (bFirst)
         {
-            wallcycle_start(wcycle, ewcPME_SPREAD);
+            inc_nrnb(nrnb, eNR_WEIGHTS, DIM * atc.numAtoms());
+        }
+        inc_nrnb(nrnb, eNR_SPREADBSP, pme->pme_order * pme->pme_order * pme->pme_order * atc.numAtoms());
 
-            /* Spread the coefficients on a grid */
-            spread_on_grid(pme, &atc, pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
+        if (!pme->bUseThreads)
+        {
+            wrap_periodic_pmegrid(pme, grid);
 
-            if (bFirst)
+            /* sum contributions to local grid from other nodes */
+            if (pme->nnodes > 1)
             {
-                inc_nrnb(nrnb, eNR_WEIGHTS, DIM * atc.numAtoms());
+                gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
             }
-            inc_nrnb(nrnb, eNR_SPREADBSP,
-                     pme->pme_order * pme->pme_order * pme->pme_order * atc.numAtoms());
-
-            if (!pme->bUseThreads)
-            {
-                wrap_periodic_pmegrid(pme, grid);
 
-                /* sum contributions to local grid from other nodes */
-                if (pme->nnodes > 1)
-                {
-                    gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
-                }
-
-                copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
-            }
+            copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
+        }
 
-            wallcycle_stop(wcycle, ewcPME_SPREAD);
+        wallcycle_stop(wcycle, ewcPME_SPREAD);
 
-            /* TODO If the OpenMP and single-threaded implementations
-               converge, then spread_on_grid() and
-               copy_pmegrid_to_fftgrid() will perhaps live in the same
-               source file.
-             */
-        }
+        /* TODO If the OpenMP and single-threaded implementations
+           converge, then spread_on_grid() and
+           copy_pmegrid_to_fftgrid() will perhaps live in the same
+           source file.
+        */
 
         /* Here we start a large thread parallel region */
 #pragma omp parallel num_threads(pme->nthread) private(thread)
@@ -1195,75 +1206,69 @@ int gmx_pme_do(struct gmx_pme_t*              pme,
             try
             {
                 thread = gmx_omp_get_thread_num();
-                if (flags & GMX_PME_SOLVE)
-                {
-                    int loop_count;
+                int loop_count;
 
-                    /* do 3d-fft */
-                    if (thread == 0)
-                    {
-                        wallcycle_start(wcycle, ewcPME_FFT);
-                    }
-                    gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX, thread, wcycle);
-                    if (thread == 0)
-                    {
-                        wallcycle_stop(wcycle, ewcPME_FFT);
-                    }
-
-                    /* solve in k-space for our local cells */
-                    if (thread == 0)
-                    {
-                        wallcycle_start(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
-                    }
-                    if (grid_index < DO_Q)
-                    {
-                        loop_count = solve_pme_yzx(
-                                pme, cfftgrid, scaledBox[XX][XX] * scaledBox[YY][YY] * scaledBox[ZZ][ZZ],
-                                bCalcEnerVir, pme->nthread, thread);
-                    }
-                    else
-                    {
-                        loop_count = solve_pme_lj_yzx(
-                                pme, &cfftgrid, FALSE,
-                                scaledBox[XX][XX] * scaledBox[YY][YY] * scaledBox[ZZ][ZZ],
-                                bCalcEnerVir, pme->nthread, thread);
-                    }
+                /* do 3d-fft */
+                if (thread == 0)
+                {
+                    wallcycle_start(wcycle, ewcPME_FFT);
+                }
+                gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX, thread, wcycle);
+                if (thread == 0)
+                {
+                    wallcycle_stop(wcycle, ewcPME_FFT);
+                }
 
-                    if (thread == 0)
-                    {
-                        wallcycle_stop(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
-                        inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
-                    }
+                /* solve in k-space for our local cells */
+                if (thread == 0)
+                {
+                    wallcycle_start(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
+                }
+                if (grid_index < DO_Q)
+                {
+                    loop_count = solve_pme_yzx(
+                            pme, cfftgrid, scaledBox[XX][XX] * scaledBox[YY][YY] * scaledBox[ZZ][ZZ],
+                            computeEnergyAndVirial, pme->nthread, thread);
+                }
+                else
+                {
+                    loop_count =
+                            solve_pme_lj_yzx(pme, &cfftgrid, FALSE,
+                                             scaledBox[XX][XX] * scaledBox[YY][YY] * scaledBox[ZZ][ZZ],
+                                             computeEnergyAndVirial, pme->nthread, thread);
                 }
 
-                if (bBackFFT)
+                if (thread == 0)
                 {
-                    /* do 3d-invfft */
-                    if (thread == 0)
-                    {
-                        wallcycle_start(wcycle, ewcPME_FFT);
-                    }
-                    gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL, thread, wcycle);
-                    if (thread == 0)
-                    {
-                        wallcycle_stop(wcycle, ewcPME_FFT);
+                    wallcycle_stop(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
+                    inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
+                }
 
+                /* do 3d-invfft */
+                if (thread == 0)
+                {
+                    wallcycle_start(wcycle, ewcPME_FFT);
+                }
+                gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL, thread, wcycle);
+                if (thread == 0)
+                {
+                    wallcycle_stop(wcycle, ewcPME_FFT);
 
-                        if (pme->nodeid == 0)
-                        {
-                            real ntot = pme->nkx * pme->nky * pme->nkz;
-                            npme      = static_cast<int>(ntot * std::log(ntot) / std::log(2.0));
-                            inc_nrnb(nrnb, eNR_FFT, 2 * npme);
-                        }
 
-                        /* Note: this wallcycle region is closed below
-                           outside an OpenMP region, so take care if
-                           refactoring code here. */
-                        wallcycle_start(wcycle, ewcPME_GATHER);
+                    if (pme->nodeid == 0)
+                    {
+                        real ntot = pme->nkx * pme->nky * pme->nkz;
+                        npme      = static_cast<int>(ntot * std::log(ntot) / std::log(2.0));
+                        inc_nrnb(nrnb, eNR_FFT, 2 * npme);
                     }
 
-                    copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
+                    /* Note: this wallcycle region is closed below
+                       outside an OpenMP region, so take care if
+                       refactoring code here. */
+                    wallcycle_start(wcycle, ewcPME_GATHER);
                 }
+
+                copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
             }
             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
         }
@@ -1271,18 +1276,15 @@ int gmx_pme_do(struct gmx_pme_t*              pme,
          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
          */
 
-        if (bBackFFT)
+        /* distribute local grid to all nodes */
+        if (pme->nnodes > 1)
         {
-            /* distribute local grid to all nodes */
-            if (pme->nnodes > 1)
-            {
-                gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
-            }
-
-            unwrap_periodic_pmegrid(pme, grid);
+            gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
         }
 
-        if (bCalcF)
+        unwrap_periodic_pmegrid(pme, grid);
+
+        if (stepWork.computeForces)
         {
             /* interpolate forces for our local atoms */
 
@@ -1312,7 +1314,7 @@ int gmx_pme_do(struct gmx_pme_t*              pme,
             wallcycle_stop(wcycle, ewcPME_GATHER);
         }
 
-        if (bCalcEnerVir)
+        if (computeEnergyAndVirial)
         {
             /* This should only be called on the master thread
              * and after the threads have synchronized.
@@ -1404,85 +1406,77 @@ int gmx_pme_do(struct gmx_pme_t*              pme,
                 calc_next_lb_coeffs(coefficientBuffer, local_sigma);
                 grid = pmegrid->grid.grid;
 
-                if (flags & GMX_PME_SPREAD)
-                {
-                    wallcycle_start(wcycle, ewcPME_SPREAD);
-                    /* Spread the c6 on a grid */
-                    spread_on_grid(pme, &atc, pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
+                wallcycle_start(wcycle, ewcPME_SPREAD);
+                /* Spread the c6 on a grid */
+                spread_on_grid(pme, &atc, pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
 
-                    if (bFirst)
-                    {
-                        inc_nrnb(nrnb, eNR_WEIGHTS, DIM * atc.numAtoms());
-                    }
+                if (bFirst)
+                {
+                    inc_nrnb(nrnb, eNR_WEIGHTS, DIM * atc.numAtoms());
+                }
 
-                    inc_nrnb(nrnb, eNR_SPREADBSP,
-                             pme->pme_order * pme->pme_order * pme->pme_order * atc.numAtoms());
-                    if (pme->nthread == 1)
+                inc_nrnb(nrnb, eNR_SPREADBSP,
+                         pme->pme_order * pme->pme_order * pme->pme_order * atc.numAtoms());
+                if (pme->nthread == 1)
+                {
+                    wrap_periodic_pmegrid(pme, grid);
+                    /* sum contributions to local grid from other nodes */
+                    if (pme->nnodes > 1)
                     {
-                        wrap_periodic_pmegrid(pme, grid);
-                        /* sum contributions to local grid from other nodes */
-                        if (pme->nnodes > 1)
-                        {
-                            gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
-                        }
-                        copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
+                        gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
                     }
-                    wallcycle_stop(wcycle, ewcPME_SPREAD);
+                    copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
                 }
+                wallcycle_stop(wcycle, ewcPME_SPREAD);
+
                 /*Here we start a large thread parallel region*/
 #pragma omp parallel num_threads(pme->nthread) private(thread)
                 {
                     try
                     {
                         thread = gmx_omp_get_thread_num();
-                        if (flags & GMX_PME_SOLVE)
+                        /* do 3d-fft */
+                        if (thread == 0)
                         {
-                            /* do 3d-fft */
-                            if (thread == 0)
-                            {
-                                wallcycle_start(wcycle, ewcPME_FFT);
-                            }
+                            wallcycle_start(wcycle, ewcPME_FFT);
+                        }
 
-                            gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX, thread, wcycle);
-                            if (thread == 0)
-                            {
-                                wallcycle_stop(wcycle, ewcPME_FFT);
-                            }
+                        gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX, thread, wcycle);
+                        if (thread == 0)
+                        {
+                            wallcycle_stop(wcycle, ewcPME_FFT);
                         }
                     }
                     GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
                 }
                 bFirst = FALSE;
             }
-            if (flags & GMX_PME_SOLVE)
-            {
-                /* solve in k-space for our local cells */
+            /* solve in k-space for our local cells */
 #pragma omp parallel num_threads(pme->nthread) private(thread)
+            {
+                try
                 {
-                    try
+                    int loop_count;
+                    thread = gmx_omp_get_thread_num();
+                    if (thread == 0)
                     {
-                        int loop_count;
-                        thread = gmx_omp_get_thread_num();
-                        if (thread == 0)
-                        {
-                            wallcycle_start(wcycle, ewcLJPME);
-                        }
+                        wallcycle_start(wcycle, ewcLJPME);
+                    }
 
-                        loop_count = solve_pme_lj_yzx(
-                                pme, &pme->cfftgrid[2], TRUE,
-                                scaledBox[XX][XX] * scaledBox[YY][YY] * scaledBox[ZZ][ZZ],
-                                bCalcEnerVir, pme->nthread, thread);
-                        if (thread == 0)
-                        {
-                            wallcycle_stop(wcycle, ewcLJPME);
-                            inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
-                        }
+                    loop_count =
+                            solve_pme_lj_yzx(pme, &pme->cfftgrid[2], TRUE,
+                                             scaledBox[XX][XX] * scaledBox[YY][YY] * scaledBox[ZZ][ZZ],
+                                             computeEnergyAndVirial, pme->nthread, thread);
+                    if (thread == 0)
+                    {
+                        wallcycle_stop(wcycle, ewcLJPME);
+                        inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
                     }
-                    GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
                 }
+                GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
             }
 
-            if (bCalcEnerVir)
+            if (computeEnergyAndVirial)
             {
                 /* This should only be called on the master thread and
                  * after the threads have synchronized.
@@ -1490,88 +1484,85 @@ int gmx_pme_do(struct gmx_pme_t*              pme,
                 get_pme_ener_vir_lj(pme->solve_work, pme->nthread, &output[fep_state]);
             }
 
-            if (bBackFFT)
+            bFirst = !pme->doCoulomb;
+            calc_initial_lb_coeffs(coefficientBuffer, local_c6, local_sigma);
+            for (grid_index = 8; grid_index >= 2; --grid_index)
             {
-                bFirst = !pme->doCoulomb;
-                calc_initial_lb_coeffs(coefficientBuffer, local_c6, local_sigma);
-                for (grid_index = 8; grid_index >= 2; --grid_index)
-                {
-                    /* Unpack structure */
-                    pmegrid    = &pme->pmegrid[grid_index];
-                    fftgrid    = pme->fftgrid[grid_index];
-                    pfft_setup = pme->pfft_setup[grid_index];
-                    grid       = pmegrid->grid.grid;
-                    calc_next_lb_coeffs(coefficientBuffer, local_sigma);
+                /* Unpack structure */
+                pmegrid    = &pme->pmegrid[grid_index];
+                fftgrid    = pme->fftgrid[grid_index];
+                pfft_setup = pme->pfft_setup[grid_index];
+                grid       = pmegrid->grid.grid;
+                calc_next_lb_coeffs(coefficientBuffer, local_sigma);
 #pragma omp parallel num_threads(pme->nthread) private(thread)
+                {
+                    try
                     {
-                        try
+                        thread = gmx_omp_get_thread_num();
+                        /* do 3d-invfft */
+                        if (thread == 0)
                         {
-                            thread = gmx_omp_get_thread_num();
-                            /* do 3d-invfft */
-                            if (thread == 0)
-                            {
-                                wallcycle_start(wcycle, ewcPME_FFT);
-                            }
+                            wallcycle_start(wcycle, ewcPME_FFT);
+                        }
 
-                            gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL, thread, wcycle);
-                            if (thread == 0)
-                            {
-                                wallcycle_stop(wcycle, ewcPME_FFT);
+                        gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL, thread, wcycle);
+                        if (thread == 0)
+                        {
+                            wallcycle_stop(wcycle, ewcPME_FFT);
 
 
-                                if (pme->nodeid == 0)
-                                {
-                                    real ntot = pme->nkx * pme->nky * pme->nkz;
-                                    npme = static_cast<int>(ntot * std::log(ntot) / std::log(2.0));
-                                    inc_nrnb(nrnb, eNR_FFT, 2 * npme);
-                                }
-                                wallcycle_start(wcycle, ewcPME_GATHER);
+                            if (pme->nodeid == 0)
+                            {
+                                real ntot = pme->nkx * pme->nky * pme->nkz;
+                                npme      = static_cast<int>(ntot * std::log(ntot) / std::log(2.0));
+                                inc_nrnb(nrnb, eNR_FFT, 2 * npme);
                             }
-
-                            copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
+                            wallcycle_start(wcycle, ewcPME_GATHER);
                         }
-                        GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
-                    } /*#pragma omp parallel*/
 
-                    /* distribute local grid to all nodes */
-                    if (pme->nnodes > 1)
-                    {
-                        gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
+                        copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
                     }
+                    GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
+                } /*#pragma omp parallel*/
 
-                    unwrap_periodic_pmegrid(pme, grid);
+                /* distribute local grid to all nodes */
+                if (pme->nnodes > 1)
+                {
+                    gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
+                }
 
-                    if (bCalcF)
-                    {
-                        /* interpolate forces for our local atoms */
-                        bClearF = (bFirst && PAR(cr));
-                        scale   = pme->bFEP ? (fep_state < 1 ? 1.0 - lambda_lj : lambda_lj) : 1.0;
-                        scale *= lb_scale_factor[grid_index - 2];
+                unwrap_periodic_pmegrid(pme, grid);
+
+                if (stepWork.computeForces)
+                {
+                    /* interpolate forces for our local atoms */
+                    bClearF = (bFirst && PAR(cr));
+                    scale   = pme->bFEP ? (fep_state < 1 ? 1.0 - lambda_lj : lambda_lj) : 1.0;
+                    scale *= lb_scale_factor[grid_index - 2];
 
 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
-                        for (thread = 0; thread < pme->nthread; thread++)
+                    for (thread = 0; thread < pme->nthread; thread++)
+                    {
+                        try
                         {
-                            try
-                            {
-                                gather_f_bsplines(pme, grid, bClearF, &pme->atc[0],
-                                                  &pme->atc[0].spline[thread], scale);
-                            }
-                            GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
+                            gather_f_bsplines(pme, grid, bClearF, &pme->atc[0],
+                                              &pme->atc[0].spline[thread], scale);
                         }
+                        GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
+                    }
 
 
-                        inc_nrnb(nrnb, eNR_GATHERFBSP,
-                                 pme->pme_order * pme->pme_order * pme->pme_order * pme->atc[0].numAtoms());
-                    }
-                    wallcycle_stop(wcycle, ewcPME_GATHER);
+                    inc_nrnb(nrnb, eNR_GATHERFBSP,
+                             pme->pme_order * pme->pme_order * pme->pme_order * pme->atc[0].numAtoms());
+                }
+                wallcycle_stop(wcycle, ewcPME_GATHER);
 
-                    bFirst = FALSE;
-                } /* for (grid_index = 8; grid_index >= 2; --grid_index) */
-            }     /* if (bCalcF) */
-        }         /* for (fep_state = 0; fep_state < fep_states_lj; ++fep_state) */
-    }             /* if ((flags & GMX_PME_DO_LJ) && pme->ljpme_combination_rule == eljpmeLB) */
+                bFirst = FALSE;
+            } /* for (grid_index = 8; grid_index >= 2; --grid_index) */
+        }     /* for (fep_state = 0; fep_state < fep_states_lj; ++fep_state) */
+    }         /* if (pme->doLJ && pme->ljpme_combination_rule == eljpmeLB) */
 
-    if (bCalcF && pme->nnodes > 1)
+    if (stepWork.computeForces && pme->nnodes > 1)
     {
         wallcycle_start(wcycle, ewcPME_REDISTXF);
         for (d = 0; d < pme->ndecompdim; d++)
@@ -1596,7 +1587,7 @@ int gmx_pme_do(struct gmx_pme_t*              pme,
         wallcycle_stop(wcycle, ewcPME_REDISTXF);
     }
 
-    if (bCalcEnerVir)
+    if (computeEnergyAndVirial)
     {
         if (pme->doCoulomb)
         {
@@ -1719,11 +1710,13 @@ void gmx_pme_destroy(gmx_pme_t* pme)
     delete pme;
 }
 
-void gmx_pme_reinit_atoms(gmx_pme_t* pme, const int numAtoms, const real* charges)
+void gmx_pme_reinit_atoms(gmx_pme_t* pme, const int numAtoms, const real* chargesA, const real* chargesB)
 {
     if (pme->gpu != nullptr)
     {
-        pme_gpu_reinit_atoms(pme->gpu, numAtoms, charges);
+        GMX_ASSERT(!(pme->bFEP_q && chargesB == nullptr),
+                   "B state charges must be specified if running Coulomb FEP on the GPU");
+        pme_gpu_reinit_atoms(pme->gpu, numAtoms, chargesA, pme->bFEP_q ? chargesB : nullptr);
     }
     else
     {