Pipeline GPU PME Spline/Spread with PP Comms
[alexxy/gromacs.git] / src / gromacs / ewald / tests / pmetestcommon.cpp
index 30c0c042a7f6f858818dccfa7133dc5803e03512..0016cd80241e43f304af1595ddd7fab1b694d799 100644 (file)
@@ -1,7 +1,8 @@
 /*
  * This file is part of the GROMACS molecular simulation package.
  *
- * Copyright (c) 2016,2017,2018,2019, by the GROMACS development team, led by
+ * Copyright (c) 2016,2017,2018,2019,2020 by the GROMACS development team.
+ * Copyright (c) 2021, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
 
 #include "gromacs/domdec/domdec.h"
 #include "gromacs/ewald/pme_gather.h"
+#include "gromacs/ewald/pme_gpu_calculate_splines.h"
+#include "gromacs/ewald/pme_gpu_constants.h"
 #include "gromacs/ewald/pme_gpu_internal.h"
+#include "gromacs/ewald/pme_gpu_staging.h"
 #include "gromacs/ewald/pme_grid.h"
 #include "gromacs/ewald/pme_internal.h"
 #include "gromacs/ewald/pme_redistribute.h"
@@ -57,6 +61,7 @@
 #include "gromacs/ewald/pme_spread.h"
 #include "gromacs/fft/parallel_3dfft.h"
 #include "gromacs/gpu_utils/gpu_utils.h"
+#include "gromacs/hardware/device_management.h"
 #include "gromacs/math/invertmatrix.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/utility/gmxassert.h"
 #include "gromacs/utility/logger.h"
 #include "gromacs/utility/stringutil.h"
+#include "gromacs/ewald/pme_coordinate_receiver_gpu.h"
 
+#include "testutils/test_hardware_environment.h"
 #include "testutils/testasserts.h"
 
+class DeviceContext;
+
 namespace gmx
 {
 namespace test
 {
 
-bool pmeSupportsInputForMode(const gmx_hw_info_t &hwinfo,
-                             const t_inputrec    *inputRec,
-                             CodePath             mode)
+bool pmeSupportsInputForMode(const gmx_hw_info_t& hwinfo, const t_inputrec* inputRec, CodePath mode)
 {
-    bool       implemented;
-    gmx_mtop_t mtop;
+    bool implemented;
     switch (mode)
     {
-        case CodePath::CPU:
-            implemented = true;
-            break;
+        case CodePath::CPU: implemented = true; break;
 
         case CodePath::GPU:
-            implemented = (pme_gpu_supports_build(nullptr) &&
-                           pme_gpu_supports_hardware(hwinfo, nullptr) &&
-                           pme_gpu_supports_input(*inputRec, mtop, nullptr));
+            implemented = (pme_gpu_supports_build(nullptr) && pme_gpu_supports_hardware(hwinfo, nullptr)
+                           && pme_gpu_supports_input(*inputRec, nullptr));
             break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
     return implemented;
 }
@@ -107,20 +109,34 @@ uint64_t getSplineModuliDoublePrecisionUlps(int splineOrder)
 }
 
 //! PME initialization
-PmeSafePointer pmeInitWrapper(const t_inputrec         *inputRec,
-                              const CodePath            mode,
-                              const gmx_device_info_t  *gpuInfo,
-                              PmeGpuProgramHandle       pmeGpuProgram,
-                              const Matrix3x3          &box,
-                              const real                ewaldCoeff_q,
-                              const real                ewaldCoeff_lj)
+PmeSafePointer pmeInitWrapper(const t_inputrec*    inputRec,
+                              const CodePath       mode,
+                              const DeviceContext* deviceContext,
+                              const DeviceStream*  deviceStream,
+                              const PmeGpuProgram* pmeGpuProgram,
+                              const Matrix3x3&     box,
+                              const real           ewaldCoeff_q,
+                              const real           ewaldCoeff_lj)
 {
     const MDLogger dummyLogger;
     const auto     runMode       = (mode == CodePath::CPU) ? PmeRunMode::CPU : PmeRunMode::Mixed;
-    t_commrec      dummyCommrec  = {0};
+    t_commrec      dummyCommrec  = { 0 };
     NumPmeDomains  numPmeDomains = { 1, 1 };
-    gmx_pme_t     *pmeDataRaw    = gmx_pme_init(&dummyCommrec, numPmeDomains, inputRec, false, false, true,
-                                                ewaldCoeff_q, ewaldCoeff_lj, 1, runMode, nullptr, gpuInfo, pmeGpuProgram, dummyLogger);
+    gmx_pme_t*     pmeDataRaw    = gmx_pme_init(&dummyCommrec,
+                                         numPmeDomains,
+                                         inputRec,
+                                         false,
+                                         false,
+                                         true,
+                                         ewaldCoeff_q,
+                                         ewaldCoeff_lj,
+                                         1,
+                                         runMode,
+                                         nullptr,
+                                         deviceContext,
+                                         deviceStream,
+                                         pmeGpuProgram,
+                                         dummyLogger);
     PmeSafePointer pme(pmeDataRaw); // taking ownership
 
     // TODO get rid of this with proper matrix type
@@ -132,64 +148,52 @@ PmeSafePointer pmeInitWrapper(const t_inputrec         *inputRec,
             boxTemp[i][j] = box[i * DIM + j];
         }
     }
-    const char *boxError = check_box(-1, boxTemp);
+    const char* boxError = check_box(PbcType::Unset, boxTemp);
     GMX_RELEASE_ASSERT(boxError == nullptr, boxError);
 
     switch (mode)
     {
-        case CodePath::CPU:
-            invertBoxMatrix(boxTemp, pme->recipbox);
-            break;
+        case CodePath::CPU: invertBoxMatrix(boxTemp, pme->recipbox); break;
 
         case CodePath::GPU:
             pme_gpu_set_testing(pme->gpu, true);
             pme_gpu_update_input_box(pme->gpu, boxTemp);
             break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
 
     return pme;
 }
 
-//! Simple PME initialization based on input, no atom data
-PmeSafePointer pmeInitEmpty(const t_inputrec         *inputRec,
-                            const CodePath            mode,
-                            const gmx_device_info_t  *gpuInfo,
-                            PmeGpuProgramHandle       pmeGpuProgram,
-                            const Matrix3x3          &box,
-                            real                      ewaldCoeff_q,
-                            real                      ewaldCoeff_lj
-                            )
+PmeSafePointer pmeInitEmpty(const t_inputrec* inputRec)
 {
-    return pmeInitWrapper(inputRec, mode, gpuInfo, pmeGpuProgram, box, ewaldCoeff_q, ewaldCoeff_lj);
-    // hiding the fact that PME actually needs to know the number of atoms in advance
+    const Matrix3x3 defaultBox = { { 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F } };
+    return pmeInitWrapper(inputRec, CodePath::CPU, nullptr, nullptr, nullptr, defaultBox, 0.0F, 0.0F);
 }
 
 //! Make a GPU state-propagator manager
-std::unique_ptr<StatePropagatorDataGpu>
-makeStatePropagatorDataGpu(const gmx_pme_t &pme)
+std::unique_ptr<StatePropagatorDataGpu> makeStatePropagatorDataGpu(const gmx_pme_t&     pme,
+                                                                   const DeviceContext* deviceContext,
+                                                                   const DeviceStream*  deviceStream)
 {
     // TODO: Pin the host buffer and use async memory copies
     // TODO: Special constructor for PME-only rank / PME-tests is used here. There should be a mechanism to
     //       restrict one from using other constructor here.
-    return std::make_unique<StatePropagatorDataGpu>(pme_gpu_get_device_stream(&pme),
-                                                    pme_gpu_get_device_context(&pme),
-                                                    GpuApiCallBehavior::Sync,
-                                                    pme_gpu_get_padding_size(&pme));
+    return std::make_unique<StatePropagatorDataGpu>(
+            deviceStream, *deviceContext, GpuApiCallBehavior::Sync, pme_gpu_get_block_size(&pme), nullptr);
 }
 
 //! PME initialization with atom data
-void pmeInitAtoms(gmx_pme_t               *pme,
-                  StatePropagatorDataGpu  *stateGpu,
+void pmeInitAtoms(gmx_pme_t*               pme,
+                  StatePropagatorDataGpu*  stateGpu,
                   const CodePath           mode,
-                  const CoordinatesVector &coordinates,
-                  const ChargesVector     &charges)
+                  const CoordinatesVector& coordinates,
+                  const ChargesVector&     charges)
 {
-    const index  atomCount = coordinates.size();
+    const index atomCount = coordinates.size();
     GMX_RELEASE_ASSERT(atomCount == charges.ssize(), "Mismatch in atom data");
-    PmeAtomComm *atc = nullptr;
+    PmeAtomComm* atc = nullptr;
 
     switch (mode)
     {
@@ -197,139 +201,175 @@ void pmeInitAtoms(gmx_pme_t               *pme,
             atc              = &(pme->atc[0]);
             atc->x           = coordinates;
             atc->coefficient = charges;
-            gmx_pme_reinit_atoms(pme, atomCount, charges.data());
+            gmx_pme_reinit_atoms(pme, atomCount, charges, {});
             /* With decomposition there would be more boilerplate atc code here, e.g. do_redist_pos_coeffs */
             break;
 
         case CodePath::GPU:
             // TODO: Avoid use of atc in the GPU code path
-            atc              = &(pme->atc[0]);
+            atc = &(pme->atc[0]);
             // We need to set atc->n for passing the size in the tests
             atc->setNumAtoms(atomCount);
-            gmx_pme_reinit_atoms(pme, atomCount, charges.data());
+            gmx_pme_reinit_atoms(pme, atomCount, charges, {});
 
             stateGpu->reinit(atomCount, atomCount);
-            stateGpu->copyCoordinatesToGpu(arrayRefFromArray(coordinates.data(), coordinates.size()), gmx::AtomLocality::All);
+            stateGpu->copyCoordinatesToGpu(arrayRefFromArray(coordinates.data(), coordinates.size()),
+                                           gmx::AtomLocality::Local);
             pme_gpu_set_kernelparam_coordinates(pme->gpu, stateGpu->getCoordinates());
 
             break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
 }
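
On the GPU path, a state propagator has to exist before the atom data can be staged on the device. A sketch with hypothetical fixture data, assuming the CoordinatesVector/ChargesVector aliases from pmetestcommon.h accept these initializers:

    CoordinatesVector coordinates{ { 0.5F, 0.5F, 0.5F }, { 1.5F, 1.5F, 1.5F } };
    std::vector<real> charges{ 1.0F, -1.0F };
    auto stateGpu = makeStatePropagatorDataGpu(*pme, deviceContext, deviceStream);
    pmeInitAtoms(pme.get(), stateGpu.get(), CodePath::GPU, coordinates, charges);
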
 
 //! Getting local PME real grid pointer for test I/O
-static real *pmeGetRealGridInternal(const gmx_pme_t *pme)
+static real* pmeGetRealGridInternal(const gmx_pme_t* pme)
 {
     const size_t gridIndex = 0;
     return pme->fftgrid[gridIndex];
 }
 
 //! Getting local PME real grid dimensions
-static void pmeGetRealGridSizesInternal(const gmx_pme_t      *pme,
-                                        CodePath              mode,
-                                        IVec                 &gridSize,       //NOLINT(google-runtime-references)
-                                        IVec                 &paddedGridSize) //NOLINT(google-runtime-references)
+static void pmeGetRealGridSizesInternal(const gmx_pme_t* pme,
+                                        CodePath         mode,
+                                        IVec&            gridSize,       //NOLINT(google-runtime-references)
+                                        IVec&            paddedGridSize) //NOLINT(google-runtime-references)
 {
     const size_t gridIndex = 0;
     IVec         gridOffsetUnused;
     switch (mode)
     {
         case CodePath::CPU:
-            gmx_parallel_3dfft_real_limits(pme->pfft_setup[gridIndex], gridSize, gridOffsetUnused, paddedGridSize);
+            gmx_parallel_3dfft_real_limits(
+                    pme->pfft_setup[gridIndex], gridSize, gridOffsetUnused, paddedGridSize);
             break;
 
         case CodePath::GPU:
             pme_gpu_get_real_grid_sizes(pme->gpu, &gridSize, &paddedGridSize);
             break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
 }
 
 //! Getting local PME complex grid pointer for test I/O
-static t_complex *pmeGetComplexGridInternal(const gmx_pme_t *pme)
+static t_complex* pmeGetComplexGridInternal(const gmx_pme_t* pme)
 {
     const size_t gridIndex = 0;
     return pme->cfftgrid[gridIndex];
 }
 
 //! Getting local PME complex grid dimensions
-static void pmeGetComplexGridSizesInternal(const gmx_pme_t      *pme,
-                                           IVec                 &gridSize,       //NOLINT(google-runtime-references)
-                                           IVec                 &paddedGridSize) //NOLINT(google-runtime-references)
+static void pmeGetComplexGridSizesInternal(const gmx_pme_t* pme,
+                                           IVec&            gridSize,       //NOLINT(google-runtime-references)
+                                           IVec&            paddedGridSize) //NOLINT(google-runtime-references)
 {
     const size_t gridIndex = 0;
     IVec         gridOffsetUnused, complexOrderUnused;
-    gmx_parallel_3dfft_complex_limits(pme->pfft_setup[gridIndex], complexOrderUnused, gridSize, gridOffsetUnused, paddedGridSize); //TODO: what about YZX ordering?
+    gmx_parallel_3dfft_complex_limits(
+            pme->pfft_setup[gridIndex], complexOrderUnused, gridSize, gridOffsetUnused, paddedGridSize); // TODO: what about YZX ordering?
 }
 
 //! Getting the PME grid memory buffer and its sizes - template definition
-template<typename ValueType> static void pmeGetGridAndSizesInternal(const gmx_pme_t * /*unused*/, CodePath /*unused*/, ValueType * & /*unused*/, IVec & /*unused*/, IVec & /*unused*/) //NOLINT(google-runtime-references)
+template<typename ValueType>
+static void pmeGetGridAndSizesInternal(const gmx_pme_t* /*unused*/,
+                                       CodePath /*unused*/,
+                                       ValueType*& /*unused*/, //NOLINT(google-runtime-references)
+                                       IVec& /*unused*/,       //NOLINT(google-runtime-references)
+                                       IVec& /*unused*/)       //NOLINT(google-runtime-references)
 {
     GMX_THROW(InternalError("Deleted function call"));
-    // explicitly deleting general template does not compile in clang/icc, see https://llvm.org/bugs/show_bug.cgi?id=17537
+    // explicitly deleting general template does not compile in clang, see https://llvm.org/bugs/show_bug.cgi?id=17537
 }
 
 //! Getting the PME real grid memory buffer and its sizes
-template<> void pmeGetGridAndSizesInternal<real>(const gmx_pme_t *pme, CodePath mode, real * &grid, IVec &gridSize, IVec &paddedGridSize)
+template<>
+void pmeGetGridAndSizesInternal<real>(const gmx_pme_t* pme, CodePath mode, real*& grid, IVec& gridSize, IVec& paddedGridSize)
 {
     grid = pmeGetRealGridInternal(pme);
     pmeGetRealGridSizesInternal(pme, mode, gridSize, paddedGridSize);
 }
 
 //! Getting the PME complex grid memory buffer and its sizes
-template<> void pmeGetGridAndSizesInternal<t_complex>(const gmx_pme_t *pme, CodePath /*unused*/, t_complex * &grid, IVec &gridSize, IVec &paddedGridSize)
+template<>
+void pmeGetGridAndSizesInternal<t_complex>(const gmx_pme_t* pme,
+                                           CodePath /*unused*/,
+                                           t_complex*& grid,
+                                           IVec&       gridSize,
+                                           IVec&       paddedGridSize)
 {
     grid = pmeGetComplexGridInternal(pme);
     pmeGetComplexGridSizesInternal(pme, gridSize, paddedGridSize);
 }
 
 //! PME spline calculation and charge spreading
-void pmePerformSplineAndSpread(gmx_pme_t *pme, CodePath mode, // TODO const qualifiers elsewhere
-                               bool computeSplines, bool spreadCharges)
+void pmePerformSplineAndSpread(gmx_pme_t* pme,
+                               CodePath   mode, // TODO const qualifiers elsewhere
+                               bool       computeSplines,
+                               bool       spreadCharges)
 {
     GMX_RELEASE_ASSERT(pme != nullptr, "PME data is not initialized");
-    PmeAtomComm    *atc                          = &(pme->atc[0]);
-    const size_t    gridIndex                    = 0;
-    const bool      computeSplinesForZeroCharges = true;
-    real           *fftgrid                      = spreadCharges ? pme->fftgrid[gridIndex] : nullptr;
-    real           *pmegrid                      = pme->pmegrid[gridIndex].grid.grid;
+    PmeAtomComm* atc                          = &(pme->atc[0]);
+    const size_t gridIndex                    = 0;
+    const bool   computeSplinesForZeroCharges = true;
+    real**       fftgrid                      = spreadCharges ? pme->fftgrid : nullptr;
+    real*        pmegrid                      = pme->pmegrid[gridIndex].grid.grid;
 
     switch (mode)
     {
         case CodePath::CPU:
-            spread_on_grid(pme, atc, &pme->pmegrid[gridIndex], computeSplines, spreadCharges,
-                           fftgrid, computeSplinesForZeroCharges, gridIndex);
+            spread_on_grid(pme,
+                           atc,
+                           &pme->pmegrid[gridIndex],
+                           computeSplines,
+                           spreadCharges,
+                           fftgrid != nullptr ? fftgrid[gridIndex] : nullptr,
+                           computeSplinesForZeroCharges,
+                           gridIndex);
             if (spreadCharges && !pme->bUseThreads)
             {
                 wrap_periodic_pmegrid(pme, pmegrid);
-                copy_pmegrid_to_fftgrid(pme, pmegrid, fftgrid, gridIndex);
+                copy_pmegrid_to_fftgrid(
+                        pme, pmegrid, fftgrid != nullptr ? fftgrid[gridIndex] : nullptr, gridIndex);
             }
             break;
 
+/* The compiler will complain about passing fftgrid (converting double ** to float **) if using
+ * double precision. GPUs are not used with double precision anyhow. */
+#if !GMX_DOUBLE
         case CodePath::GPU:
         {
+            const real lambdaQ = 1.0;
             // no synchronization needed as x is transferred in the PME stream
-            GpuEventSynchronizer *xReadyOnDevice = nullptr;
-            pme_gpu_spread(pme->gpu, xReadyOnDevice, gridIndex, fftgrid, computeSplines, spreadCharges);
+            GpuEventSynchronizer* xReadyOnDevice = nullptr;
+
+            bool                           useGpuDirectComm         = false;
+            gmx::PmeCoordinateReceiverGpu* pmeCoordinateReceiverGpu = nullptr;
+
+            pme_gpu_spread(pme->gpu,
+                           xReadyOnDevice,
+                           fftgrid,
+                           computeSplines,
+                           spreadCharges,
+                           lambdaQ,
+                           useGpuDirectComm,
+                           pmeCoordinateReceiverGpu);
         }
         break;
+#endif
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
 }
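
The new pme_gpu_spread() arguments (lambdaQ, useGpuDirectComm, pmeCoordinateReceiverGpu) support the pipelined PP-PME coordinate transfer this change introduces; the tests keep direct communication disabled. From a test, the call sequence is unchanged; a sketch:

    pmePerformSplineAndSpread(pme.get(), CodePath::GPU, /*computeSplines*/ true, /*spreadCharges*/ true);
    pmeFinalizeTest(pme.get(), CodePath::GPU);
    SparseRealGridValuesOutput nonZeroGridValues = pmeGetRealGrid(pme.get(), CodePath::GPU);
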
 
 //! Getting the internal spline data buffer pointer
-static real *pmeGetSplineDataInternal(const gmx_pme_t *pme, PmeSplineDataType type, int dimIndex)
+static real* pmeGetSplineDataInternal(const gmx_pme_t* pme, PmeSplineDataType type, int dimIndex)
 {
     GMX_ASSERT((0 <= dimIndex) && (dimIndex < DIM), "Invalid dimension index");
-    const PmeAtomComm    *atc          = &(pme->atc[0]);
-    const size_t          threadIndex  = 0;
-    real                 *splineBuffer = nullptr;
+    const PmeAtomComm* atc          = &(pme->atc[0]);
+    const size_t       threadIndex  = 0;
+    real*              splineBuffer = nullptr;
     switch (type)
     {
         case PmeSplineDataType::Values:
@@ -340,20 +380,23 @@ static real *pmeGetSplineDataInternal(const gmx_pme_t *pme, PmeSplineDataType ty
             splineBuffer = atc->spline[threadIndex].dtheta.coefficients[dimIndex];
             break;
 
-        default:
-            GMX_THROW(InternalError("Unknown spline data type"));
+        default: GMX_THROW(InternalError("Unknown spline data type"));
     }
     return splineBuffer;
 }
 
 //! PME solving
-void pmePerformSolve(const gmx_pme_t *pme, CodePath mode,
-                     PmeSolveAlgorithm method, real cellVolume,
-                     GridOrdering gridOrdering, bool computeEnergyAndVirial)
-{
-    t_complex      *h_grid                 = pmeGetComplexGridInternal(pme);
-    const bool      useLorentzBerthelot    = false;
-    const size_t    threadIndex            = 0;
+void pmePerformSolve(const gmx_pme_t*  pme,
+                     CodePath          mode,
+                     PmeSolveAlgorithm method,
+                     real              cellVolume,
+                     GridOrdering      gridOrdering,
+                     bool              computeEnergyAndVirial)
+{
+    t_complex*   h_grid              = pmeGetComplexGridInternal(pme);
+    const bool   useLorentzBerthelot = false;
+    const size_t threadIndex         = 0;
+    const size_t gridIndex           = 0;
     switch (mode)
     {
         case CodePath::CPU:
@@ -364,17 +407,20 @@ void pmePerformSolve(const gmx_pme_t *pme, CodePath mode,
             switch (method)
             {
                 case PmeSolveAlgorithm::Coulomb:
-                    solve_pme_yzx(pme, h_grid, cellVolume,
-                                  computeEnergyAndVirial, pme->nthread, threadIndex);
+                    solve_pme_yzx(pme, h_grid, cellVolume, computeEnergyAndVirial, pme->nthread, threadIndex);
                     break;
 
                 case PmeSolveAlgorithm::LennardJones:
-                    solve_pme_lj_yzx(pme, &h_grid, useLorentzBerthelot,
-                                     cellVolume, computeEnergyAndVirial, pme->nthread, threadIndex);
+                    solve_pme_lj_yzx(pme,
+                                     &h_grid,
+                                     useLorentzBerthelot,
+                                     cellVolume,
+                                     computeEnergyAndVirial,
+                                     pme->nthread,
+                                     threadIndex);
                     break;
 
-                default:
-                    GMX_THROW(InternalError("Test not implemented for this mode"));
+                default: GMX_THROW(InternalError("Test not implemented for this mode"));
             }
             break;
 
@@ -382,32 +428,28 @@ void pmePerformSolve(const gmx_pme_t *pme, CodePath mode,
             switch (method)
             {
                 case PmeSolveAlgorithm::Coulomb:
-                    pme_gpu_solve(pme->gpu, h_grid, gridOrdering, computeEnergyAndVirial);
+                    pme_gpu_solve(pme->gpu, gridIndex, h_grid, gridOrdering, computeEnergyAndVirial);
                     break;
 
-                default:
-                    GMX_THROW(InternalError("Test not implemented for this mode"));
+                default: GMX_THROW(InternalError("Test not implemented for this mode"));
             }
             break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
 }
 
 //! PME force gathering
-void pmePerformGather(gmx_pme_t *pme, CodePath mode,
-                      PmeForceOutputHandling inputTreatment, ForcesVector &forces)
+void pmePerformGather(gmx_pme_t* pme, CodePath mode, ForcesVector& forces)
 {
-    PmeAtomComm    *atc                     = &(pme->atc[0]);
-    const index     atomCount               = atc->numAtoms();
+    PmeAtomComm* atc       = &(pme->atc[0]);
+    const index  atomCount = atc->numAtoms();
     GMX_RELEASE_ASSERT(forces.ssize() == atomCount, "Invalid force buffer size");
-    const bool      forceReductionWithInput = (inputTreatment == PmeForceOutputHandling::ReduceWithInput);
-    const real      scale                   = 1.0;
-    const size_t    threadIndex             = 0;
-    const size_t    gridIndex               = 0;
-    real           *pmegrid                 = pme->pmegrid[gridIndex].grid.grid;
-    real           *fftgrid                 = pme->fftgrid[gridIndex];
+    const real   scale       = 1.0;
+    const size_t threadIndex = 0;
+    const size_t gridIndex   = 0;
+    real*        pmegrid     = pme->pmegrid[gridIndex].grid.grid;
+    real**       fftgrid     = pme->fftgrid;
 
     switch (mode)
     {
@@ -418,57 +460,202 @@ void pmePerformGather(gmx_pme_t *pme, CodePath mode,
                 // something which is normally done in serial spline computation (make_thread_local_ind())
                 atc->spline[threadIndex].n = atomCount;
             }
-            copy_fftgrid_to_pmegrid(pme, fftgrid, pmegrid, gridIndex, pme->nthread, threadIndex);
+            copy_fftgrid_to_pmegrid(pme, fftgrid[gridIndex], pmegrid, gridIndex, pme->nthread, threadIndex);
             unwrap_periodic_pmegrid(pme, pmegrid);
-            gather_f_bsplines(pme, pmegrid, !forceReductionWithInput, atc, &atc->spline[threadIndex], scale);
+            gather_f_bsplines(pme, pmegrid, true, atc, &atc->spline[threadIndex], scale);
             break;
 
+/* The compiler will complain about passing fftgrid (converting double ** to float **) if using
+ * double precision. GPUs are not used with double precision anyhow. */
+#if !GMX_DOUBLE
         case CodePath::GPU:
         {
             // Variable initialization needs a non-switch scope
-            PmeOutput output = pme_gpu_getOutput(*pme, GMX_PME_CALC_F);
-            GMX_ASSERT(forces.size() == output.forces_.size(), "Size of force buffers did not match");
-            if (forceReductionWithInput)
-            {
-                std::copy(std::begin(forces), std::end(forces), std::begin(output.forces_));
-            }
-            pme_gpu_gather(pme->gpu, inputTreatment, reinterpret_cast<float *>(fftgrid));
+            const bool computeEnergyAndVirial = false;
+            const real lambdaQ                = 1.0;
+            PmeOutput  output = pme_gpu_getOutput(*pme, computeEnergyAndVirial, lambdaQ);
+            GMX_ASSERT(forces.size() == output.forces_.size(),
+                       "Size of force buffers did not match");
+            pme_gpu_gather(pme->gpu, fftgrid, lambdaQ);
             std::copy(std::begin(output.forces_), std::end(output.forces_), std::begin(forces));
         }
         break;
+#endif
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
 }
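
With the PmeForceOutputHandling parameter gone, gathered forces always overwrite the output buffer. A usage sketch, assuming ForcesVector is an ArrayRef-style view over caller-owned storage (as declared in pmetestcommon.h) and atomCount matches the initialized atom data:

    std::vector<RVec> forcesStorage(atomCount);
    ForcesVector      forces(forcesStorage);
    pmePerformGather(pme.get(), CodePath::GPU, forces);
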
 
 //! PME test finalization before fetching the outputs
-void pmeFinalizeTest(const gmx_pme_t *pme, CodePath mode)
+void pmeFinalizeTest(const gmx_pme_t* pme, CodePath mode)
 {
     switch (mode)
     {
-        case CodePath::CPU:
+        case CodePath::CPU: break;
+
+        case CodePath::GPU: pme_gpu_synchronize(pme->gpu); break;
+
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
+    }
+}
+
+//! A binary enum for spline data layout transformation
+enum class PmeLayoutTransform
+{
+    GpuToHost,
+    HostToGpu
+};
+
+/*! \brief Gets a unique index to an element in a spline parameter buffer.
+ *
+ * These theta/dtheta buffers are laid out for GPU spread/gather
+ * kernels. The index is wrt the execution block, in range(0,
+ * atomsPerBlock * order * DIM).
+ *
+ * This is a wrapper, only used in unit tests.
+ * \param[in] order            PME order
+ * \param[in] splineIndex      Spline contribution index (from 0 to \p order - 1)
+ * \param[in] dimIndex         Dimension index (from 0 to 2)
+ * \param[in] atomIndex        Atom index wrt the block.
+ * \param[in] atomsPerWarp     Number of atoms processed by a warp.
+ *
+ * \returns Index into theta or dtheta array using GPU layout.
+ */
+static int getSplineParamFullIndex(int order, int splineIndex, int dimIndex, int atomIndex, int atomsPerWarp)
+{
+    if (order != c_pmeGpuOrder)
+    {
+        throw order;
+    }
+    constexpr int fixedOrder = c_pmeGpuOrder;
+    GMX_UNUSED_VALUE(fixedOrder);
+
+    const int atomWarpIndex = atomIndex % atomsPerWarp;
+    const int warpIndex     = atomIndex / atomsPerWarp;
+    int       indexBase, result;
+    switch (atomsPerWarp)
+    {
+        case 1:
+            indexBase = getSplineParamIndexBase<fixedOrder, 1>(warpIndex, atomWarpIndex);
+            result    = getSplineParamIndex<fixedOrder, 1>(indexBase, dimIndex, splineIndex);
             break;
 
-        case CodePath::GPU:
-            pme_gpu_synchronize(pme->gpu);
+        case 2:
+            indexBase = getSplineParamIndexBase<fixedOrder, 2>(warpIndex, atomWarpIndex);
+            result    = getSplineParamIndex<fixedOrder, 2>(indexBase, dimIndex, splineIndex);
+            break;
+
+        case 4:
+            indexBase = getSplineParamIndexBase<fixedOrder, 4>(warpIndex, atomWarpIndex);
+            result    = getSplineParamIndex<fixedOrder, 4>(indexBase, dimIndex, splineIndex);
+            break;
+
+        case 8:
+            indexBase = getSplineParamIndexBase<fixedOrder, 8>(warpIndex, atomWarpIndex);
+            result    = getSplineParamIndex<fixedOrder, 8>(indexBase, dimIndex, splineIndex);
             break;
 
         default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+            GMX_THROW(NotImplementedError(
+                    formatString("Test function call not unrolled for atomsPerWarp = %d in "
+                                 "getSplineParamFullIndex",
+                                 atomsPerWarp)));
     }
+    return result;
 }
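
For example, locating the first spline value of the sixth atom along Y, in a layout with two atoms per warp (order must equal c_pmeGpuOrder, i.e. 4):

    const int gpuValueIndex = getSplineParamFullIndex(
            /*order*/ 4, /*splineIndex*/ 0, /*dimIndex*/ YY, /*atomIndex*/ 5, /*atomsPerWarp*/ 2);
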
 
-//! Setting atom spline values/derivatives to be used in spread/gather
-void pmeSetSplineData(const gmx_pme_t *pme, CodePath mode,
-                      const SplineParamsDimVector &splineValues, PmeSplineDataType type, int dimIndex)
+/*!\brief Return the number of atoms per warp */
+static int pme_gpu_get_atoms_per_warp(const PmeGpu* pmeGpu)
 {
-    const PmeAtomComm    *atc         = &(pme->atc[0]);
-    const index           atomCount   = atc->numAtoms();
-    const index           pmeOrder    = pme->pme_order;
-    const index           dimSize     = pmeOrder * atomCount;
+    const int order = pmeGpu->common->pme_order;
+    const int threadsPerAtom =
+            (pmeGpu->settings.threadsPerAtom == ThreadsPerAtom::Order ? order : order * order);
+    return pmeGpu->programHandle_->warpSize() / threadsPerAtom;
+}
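
As a worked example: with a 32-wide warp and the fixed order of 4, ThreadsPerAtom::Order uses 4 threads per atom, giving 32 / 4 = 8 atoms per warp, while the order * order layout uses 16 threads per atom, giving 32 / 16 = 2 atoms per warp.
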
+
+/*! \brief Rearranges the atom spline data between the GPU and host layouts.
+ * Only used for test purposes so far, likely to be horribly slow.
+ *
+ * \param[in]  pmeGpu     The PME GPU structure.
+ * \param[out] atc        The PME CPU atom data structure (with a single-threaded layout).
+ * \param[in]  type       The spline data type (values or derivatives).
+ * \param[in]  dimIndex   Dimension index.
+ * \param[in]  transform  Layout transform type
+ */
+static void pme_gpu_transform_spline_atom_data(const PmeGpu*      pmeGpu,
+                                               const PmeAtomComm* atc,
+                                               PmeSplineDataType  type,
+                                               int                dimIndex,
+                                               PmeLayoutTransform transform)
+{
+    // The GPU atom spline data is laid out in a different way currently than the CPU one.
+    // This function converts the data from GPU to CPU layout (in the host memory).
+    // It is only intended for testing purposes so far.
+    // Ideally we should use similar layouts on CPU and GPU if we care about mixed modes and their
+    // performance (e.g. spreading on GPU, gathering on CPU).
+    GMX_RELEASE_ASSERT(atc->nthread == 1, "Only the serial PME data layout is supported");
+    const uintmax_t threadIndex  = 0;
+    const auto      atomCount    = atc->numAtoms();
+    const auto      atomsPerWarp = pme_gpu_get_atoms_per_warp(pmeGpu);
+    const auto      pmeOrder     = pmeGpu->common->pme_order;
+    GMX_ASSERT(pmeOrder == c_pmeGpuOrder, "Only PME order 4 is implemented");
+
+    real*  cpuSplineBuffer;
+    float* h_splineBuffer;
+    switch (type)
+    {
+        case PmeSplineDataType::Values:
+            cpuSplineBuffer = atc->spline[threadIndex].theta.coefficients[dimIndex];
+            h_splineBuffer  = pmeGpu->staging.h_theta;
+            break;
+
+        case PmeSplineDataType::Derivatives:
+            cpuSplineBuffer = atc->spline[threadIndex].dtheta.coefficients[dimIndex];
+            h_splineBuffer  = pmeGpu->staging.h_dtheta;
+            break;
+
+        default: GMX_THROW(InternalError("Unknown spline data type"));
+    }
+
+    for (auto atomIndex = 0; atomIndex < atomCount; atomIndex++)
+    {
+        for (auto orderIndex = 0; orderIndex < pmeOrder; orderIndex++)
+        {
+            const auto gpuValueIndex =
+                    getSplineParamFullIndex(pmeOrder, orderIndex, dimIndex, atomIndex, atomsPerWarp);
+            const auto cpuValueIndex = atomIndex * pmeOrder + orderIndex;
+            GMX_ASSERT(cpuValueIndex < atomCount * pmeOrder,
+                       "Atom spline data index out of bounds (while transforming GPU data layout "
+                       "for host)");
+            switch (transform)
+            {
+                case PmeLayoutTransform::GpuToHost:
+                    cpuSplineBuffer[cpuValueIndex] = h_splineBuffer[gpuValueIndex];
+                    break;
+
+                case PmeLayoutTransform::HostToGpu:
+                    h_splineBuffer[gpuValueIndex] = cpuSplineBuffer[cpuValueIndex];
+                    break;
+
+                default: GMX_THROW(InternalError("Unknown layout transform"));
+            }
+        }
+    }
+}
+
+//! Setting atom spline values/derivatives to be used in spread/gather
+void pmeSetSplineData(const gmx_pme_t*             pme,
+                      CodePath                     mode,
+                      const SplineParamsDimVector& splineValues,
+                      PmeSplineDataType            type,
+                      int                          dimIndex)
+{
+    const PmeAtomComm* atc       = &(pme->atc[0]);
+    const index        atomCount = atc->numAtoms();
+    const index        pmeOrder  = pme->pme_order;
+    const index        dimSize   = pmeOrder * atomCount;
     GMX_RELEASE_ASSERT(dimSize == splineValues.ssize(), "Mismatch in spline data");
-    real                 *splineBuffer = pmeGetSplineDataInternal(pme, type, dimIndex);
+    real* splineBuffer = pmeGetSplineDataInternal(pme, type, dimIndex);
 
     switch (mode)
     {
@@ -481,47 +668,47 @@ void pmeSetSplineData(const gmx_pme_t *pme, CodePath mode,
             pme_gpu_transform_spline_atom_data(pme->gpu, atc, type, dimIndex, PmeLayoutTransform::HostToGpu);
             break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
 }
 
 //! Setting gridline indices to be used in spread/gather
-void pmeSetGridLineIndices(gmx_pme_t *pme, CodePath mode,
-                           const GridLineIndicesVector &gridLineIndices)
+void pmeSetGridLineIndices(gmx_pme_t* pme, CodePath mode, const GridLineIndicesVector& gridLineIndices)
 {
-    PmeAtomComm                *atc         = &(pme->atc[0]);
-    const index                 atomCount   = atc->numAtoms();
+    PmeAtomComm* atc       = &(pme->atc[0]);
+    const index  atomCount = atc->numAtoms();
     GMX_RELEASE_ASSERT(atomCount == gridLineIndices.ssize(), "Mismatch in gridline indices size");
 
     IVec paddedGridSizeUnused, gridSize(0, 0, 0);
     pmeGetRealGridSizesInternal(pme, mode, gridSize, paddedGridSizeUnused);
 
-    for (const auto &index : gridLineIndices)
+    for (const auto& index : gridLineIndices)
     {
         for (int i = 0; i < DIM; i++)
         {
-            GMX_RELEASE_ASSERT((0 <= index[i]) && (index[i] < gridSize[i]), "Invalid gridline index");
+            GMX_RELEASE_ASSERT((0 <= index[i]) && (index[i] < gridSize[i]),
+                               "Invalid gridline index");
         }
     }
 
     switch (mode)
     {
         case CodePath::GPU:
-            memcpy(pme->gpu->staging.h_gridlineIndices, gridLineIndices.data(), atomCount * sizeof(gridLineIndices[0]));
+            memcpy(pme_gpu_staging(pme->gpu).h_gridlineIndices,
+                   gridLineIndices.data(),
+                   atomCount * sizeof(gridLineIndices[0]));
             break;
 
         case CodePath::CPU:
             atc->idx.resize(gridLineIndices.size());
             std::copy(gridLineIndices.begin(), gridLineIndices.end(), atc->idx.begin());
             break;
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
 }
 
 //! Getting plain index into the complex 3d grid
-inline size_t pmeGetGridPlainIndexInternal(const IVec &index, const IVec &paddedGridSize, GridOrdering gridOrdering)
+inline size_t pmeGetGridPlainIndexInternal(const IVec& index, const IVec& paddedGridSize, GridOrdering gridOrdering)
 {
     size_t result;
     switch (gridOrdering)
@@ -534,20 +721,20 @@ inline size_t pmeGetGridPlainIndexInternal(const IVec &index, const IVec &padded
             result = (index[XX] * paddedGridSize[YY] + index[YY]) * paddedGridSize[ZZ] + index[ZZ];
             break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
     return result;
 }
 
 //! Setting real or complex grid
 template<typename ValueType>
-static void pmeSetGridInternal(const gmx_pme_t *pme, CodePath mode,
-                               GridOrdering gridOrdering,
-                               const SparseGridValuesInput<ValueType> &gridValues)
+static void pmeSetGridInternal(const gmx_pme_t*                        pme,
+                               CodePath                                mode,
+                               GridOrdering                            gridOrdering,
+                               const SparseGridValuesInput<ValueType>& gridValues)
 {
     IVec       gridSize(0, 0, 0), paddedGridSize(0, 0, 0);
-    ValueType *grid;
+    ValueType* grid;
     pmeGetGridAndSizesInternal<ValueType>(pme, mode, grid, gridSize, paddedGridSize);
 
     switch (mode)
@@ -555,49 +742,49 @@ static void pmeSetGridInternal(const gmx_pme_t *pme, CodePath mode,
         case CodePath::GPU: // intentional absence of break, the grid will be copied from the host buffer in testing mode
         case CodePath::CPU:
             std::memset(grid, 0, paddedGridSize[XX] * paddedGridSize[YY] * paddedGridSize[ZZ] * sizeof(ValueType));
-            for (const auto &gridValue : gridValues)
+            for (const auto& gridValue : gridValues)
             {
                 for (int i = 0; i < DIM; i++)
                 {
-                    GMX_RELEASE_ASSERT((0 <= gridValue.first[i]) && (gridValue.first[i] < gridSize[i]), "Invalid grid value index");
+                    GMX_RELEASE_ASSERT((0 <= gridValue.first[i]) && (gridValue.first[i] < gridSize[i]),
+                                       "Invalid grid value index");
                 }
-                const size_t gridValueIndex = pmeGetGridPlainIndexInternal(gridValue.first, paddedGridSize, gridOrdering);
+                const size_t gridValueIndex =
+                        pmeGetGridPlainIndexInternal(gridValue.first, paddedGridSize, gridOrdering);
                 grid[gridValueIndex] = gridValue.second;
             }
             break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
 }
 
 //! Setting real grid to be used in gather
-void pmeSetRealGrid(const gmx_pme_t *pme, CodePath mode,
-                    const SparseRealGridValuesInput &gridValues)
+void pmeSetRealGrid(const gmx_pme_t* pme, CodePath mode, const SparseRealGridValuesInput& gridValues)
 {
     pmeSetGridInternal<real>(pme, mode, GridOrdering::XYZ, gridValues);
 }
 
 //! Setting complex grid to be used in solve
-void pmeSetComplexGrid(const gmx_pme_t *pme, CodePath mode,
-                       GridOrdering gridOrdering,
-                       const SparseComplexGridValuesInput &gridValues)
+void pmeSetComplexGrid(const gmx_pme_t*                    pme,
+                       CodePath                            mode,
+                       GridOrdering                        gridOrdering,
+                       const SparseComplexGridValuesInput& gridValues)
 {
     pmeSetGridInternal<t_complex>(pme, mode, gridOrdering, gridValues);
 }
 
 //! Getting the single dimension's spline values or derivatives
-SplineParamsDimVector pmeGetSplineData(const gmx_pme_t *pme, CodePath mode,
-                                       PmeSplineDataType type, int dimIndex)
+SplineParamsDimVector pmeGetSplineData(const gmx_pme_t* pme, CodePath mode, PmeSplineDataType type, int dimIndex)
 {
     GMX_RELEASE_ASSERT(pme != nullptr, "PME data is not initialized");
-    const PmeAtomComm       *atc         = &(pme->atc[0]);
-    const size_t             atomCount   = atc->numAtoms();
-    const size_t             pmeOrder    = pme->pme_order;
-    const size_t             dimSize     = pmeOrder * atomCount;
+    const PmeAtomComm* atc       = &(pme->atc[0]);
+    const size_t       atomCount = atc->numAtoms();
+    const size_t       pmeOrder  = pme->pme_order;
+    const size_t       dimSize   = pmeOrder * atomCount;
 
-    real                    *sourceBuffer = pmeGetSplineDataInternal(pme, type, dimIndex);
-    SplineParamsDimVector    result;
+    real*                 sourceBuffer = pmeGetSplineDataInternal(pme, type, dimIndex);
+    SplineParamsDimVector result;
     switch (mode)
     {
         case CodePath::GPU:
@@ -605,46 +792,43 @@ SplineParamsDimVector pmeGetSplineData(const gmx_pme_t *pme, CodePath mode,
             result = arrayRefFromArray(sourceBuffer, dimSize);
             break;
 
-        case CodePath::CPU:
-            result = arrayRefFromArray(sourceBuffer, dimSize);
-            break;
+        case CodePath::CPU: result = arrayRefFromArray(sourceBuffer, dimSize); break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
     return result;
 }
 
 //! Getting the gridline indices
-GridLineIndicesVector pmeGetGridlineIndices(const gmx_pme_t *pme, CodePath mode)
+GridLineIndicesVector pmeGetGridlineIndices(const gmx_pme_t* pme, CodePath mode)
 {
     GMX_RELEASE_ASSERT(pme != nullptr, "PME data is not initialized");
-    const PmeAtomComm    *atc         = &(pme->atc[0]);
-    const size_t          atomCount   = atc->numAtoms();
+    const PmeAtomComm* atc       = &(pme->atc[0]);
+    const size_t       atomCount = atc->numAtoms();
 
     GridLineIndicesVector gridLineIndices;
     switch (mode)
     {
         case CodePath::GPU:
-            gridLineIndices = arrayRefFromArray(reinterpret_cast<IVec *>(pme->gpu->staging.h_gridlineIndices), atomCount);
+            gridLineIndices = arrayRefFromArray(
+                    reinterpret_cast<IVec*>(pme_gpu_staging(pme->gpu).h_gridlineIndices), atomCount);
             break;
 
-        case CodePath::CPU:
-            gridLineIndices = atc->idx;
-            break;
+        case CodePath::CPU: gridLineIndices = atc->idx; break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
     return gridLineIndices;
 }
 
 //! Getting real or complex grid - only non zero values
 template<typename ValueType>
-static SparseGridValuesOutput<ValueType> pmeGetGridInternal(const gmx_pme_t *pme, CodePath mode, GridOrdering gridOrdering)
+static SparseGridValuesOutput<ValueType> pmeGetGridInternal(const gmx_pme_t* pme,
+                                                            CodePath         mode,
+                                                            GridOrdering     gridOrdering)
 {
     IVec       gridSize(0, 0, 0), paddedGridSize(0, 0, 0);
-    ValueType *grid;
+    ValueType* grid;
     pmeGetGridAndSizesInternal<ValueType>(pme, mode, grid, gridSize, paddedGridSize);
     SparseGridValuesOutput<ValueType> gridValues;
     switch (mode)
@@ -658,12 +842,13 @@ static SparseGridValuesOutput<ValueType> pmeGetGridInternal(const gmx_pme_t *pme
                 {
                     for (int iz = 0; iz < gridSize[ZZ]; iz++)
                     {
-                        IVec            temp(ix, iy, iz);
-                        const size_t    gridValueIndex = pmeGetGridPlainIndexInternal(temp, paddedGridSize, gridOrdering);
-                        const ValueType value          = grid[gridValueIndex];
-                        if (value != ValueType {})
+                        IVec         temp(ix, iy, iz);
+                        const size_t gridValueIndex =
+                                pmeGetGridPlainIndexInternal(temp, paddedGridSize, gridOrdering);
+                        const ValueType value = grid[gridValueIndex];
+                        if (value != ValueType{})
                         {
-                            auto key = formatString("Cell %d %d %d", ix, iy, iz);
+                            auto key        = formatString("Cell %d %d %d", ix, iy, iz);
                             gridValues[key] = value;
                         }
                     }
@@ -671,30 +856,28 @@ static SparseGridValuesOutput<ValueType> pmeGetGridInternal(const gmx_pme_t *pme
             }
             break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
     return gridValues;
 }
 
 //! Getting the real grid (spreading output of pmePerformSplineAndSpread())
-SparseRealGridValuesOutput pmeGetRealGrid(const gmx_pme_t *pme, CodePath mode)
+SparseRealGridValuesOutput pmeGetRealGrid(const gmx_pme_t* pme, CodePath mode)
 {
     return pmeGetGridInternal<real>(pme, mode, GridOrdering::XYZ);
 }
 
 //! Getting the complex grid output of pmePerformSolve()
-SparseComplexGridValuesOutput pmeGetComplexGrid(const gmx_pme_t *pme, CodePath mode,
-                                                GridOrdering gridOrdering)
+SparseComplexGridValuesOutput pmeGetComplexGrid(const gmx_pme_t* pme, CodePath mode, GridOrdering gridOrdering)
 {
     return pmeGetGridInternal<t_complex>(pme, mode, gridOrdering);
 }
 
 //! Getting the reciprocal energy and virial
-PmeOutput pmeGetReciprocalEnergyAndVirial(const gmx_pme_t *pme, CodePath mode,
-                                          PmeSolveAlgorithm method)
+PmeOutput pmeGetReciprocalEnergyAndVirial(const gmx_pme_t* pme, CodePath mode, PmeSolveAlgorithm method)
 {
-    PmeOutput output;
+    PmeOutput  output;
+    const real lambdaQ = 1.0;
     switch (mode)
     {
         case CodePath::CPU:
@@ -708,27 +891,76 @@ PmeOutput pmeGetReciprocalEnergyAndVirial(const gmx_pme_t *pme, CodePath mode,
                     get_pme_ener_vir_lj(pme->solve_work, pme->nthread, &output);
                     break;
 
-                default:
-                    GMX_THROW(InternalError("Test not implemented for this mode"));
+                default: GMX_THROW(InternalError("Test not implemented for this mode"));
             }
             break;
         case CodePath::GPU:
             switch (method)
             {
                 case PmeSolveAlgorithm::Coulomb:
-                    pme_gpu_getEnergyAndVirial(*pme, &output);
+                    pme_gpu_getEnergyAndVirial(*pme, lambdaQ, &output);
                     break;
 
-                default:
-                    GMX_THROW(InternalError("Test not implemented for this mode"));
+                default: GMX_THROW(InternalError("Test not implemented for this mode"));
             }
             break;
 
-        default:
-            GMX_THROW(InternalError("Test not implemented for this mode"));
+        default: GMX_THROW(InternalError("Test not implemented for this mode"));
     }
     return output;
 }
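
A GPU-path solve test would then read the reduced quantities back through this accessor; a sketch (cellVolume is a test-supplied value):

    pmePerformSolve(pme.get(), CodePath::GPU, PmeSolveAlgorithm::Coulomb, cellVolume, GridOrdering::XYZ, true);
    pmeFinalizeTest(pme.get(), CodePath::GPU);
    PmeOutput output = pmeGetReciprocalEnergyAndVirial(pme.get(), CodePath::GPU, PmeSolveAlgorithm::Coulomb);
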
 
-}  // namespace test
-}  // namespace gmx
+const char* codePathToString(CodePath codePath)
+{
+    switch (codePath)
+    {
+        case CodePath::CPU: return "CPU";
+        case CodePath::GPU: return "GPU";
+        default: GMX_THROW(NotImplementedError("This CodePath should support codePathToString"));
+    }
+}
+
+PmeTestHardwareContext::PmeTestHardwareContext() : codePath_(CodePath::CPU) {}
+
+PmeTestHardwareContext::PmeTestHardwareContext(TestDevice* testDevice) :
+    codePath_(CodePath::GPU), testDevice_(testDevice)
+{
+    setActiveDevice(testDevice_->deviceInfo());
+    pmeGpuProgram_ = buildPmeGpuProgram(testDevice_->deviceContext());
+}
+
+//! Returns a human-readable context description line
+std::string PmeTestHardwareContext::description() const
+{
+    switch (codePath_)
+    {
+        case CodePath::CPU: return "CPU";
+        case CodePath::GPU: return "GPU (" + testDevice_->description() + ")";
+        default: return "Unknown code path.";
+    }
+}
+
+void PmeTestHardwareContext::activate() const
+{
+    if (codePath_ == CodePath::GPU)
+    {
+        setActiveDevice(testDevice_->deviceInfo());
+    }
+}
+
+std::vector<std::unique_ptr<PmeTestHardwareContext>> createPmeTestHardwareContextList()
+{
+    std::vector<std::unique_ptr<PmeTestHardwareContext>> pmeTestHardwareContextList;
+    // Add CPU
+    pmeTestHardwareContextList.emplace_back(std::make_unique<PmeTestHardwareContext>());
+    // Add GPU devices
+    const auto& testDeviceList = getTestHardwareEnvironment()->getTestDeviceList();
+    for (const auto& testDevice : testDeviceList)
+    {
+        pmeTestHardwareContextList.emplace_back(std::make_unique<PmeTestHardwareContext>(testDevice.get()));
+    }
+    return pmeTestHardwareContextList;
+}
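
Tests iterate this list and activate each context before running a code path. A sketch, assuming the accessors declared alongside PmeTestHardwareContext (e.g. codePath()) and a hwinfo obtained from the test hardware environment:

    for (const auto& context : createPmeTestHardwareContextList())
    {
        context->activate();
        if (!pmeSupportsInputForMode(hwinfo, &inputRec, context->codePath()))
        {
            continue; // skip input/mode combinations the build or device cannot run
        }
        // ... exercise spread/solve/gather for this context ...
    }
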
+
+} // namespace test
+} // namespace gmx