SIMD support for nonbonded free-energy kernels
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_free_energy.cpp
index d73bc7fb0eca796d1fe8b57cbf7160d3f497b34e..45e09c25bc9169a97087f59b92e16f271236d0b1 100644 (file)
 #include "config.h"
 
 #include <cmath>
+#include <set>
 
 #include <algorithm>
 
 #include "gromacs/gmxlib/nrnb.h"
 #include "gromacs/gmxlib/nonbonded/nonbonded.h"
+#include "gromacs/math/arrayrefwithpadding.h"
 #include "gromacs/math/functions.h"
 #include "gromacs/math/vec.h"
 #include "gromacs/mdtypes/forceoutput.h"
@@ -55,7 +57,9 @@
 #include "gromacs/mdtypes/md_enums.h"
 #include "gromacs/mdtypes/mdatom.h"
 #include "gromacs/mdtypes/nblist.h"
+#include "gromacs/pbcutil/ishift.h"
 #include "gromacs/simd/simd.h"
+#include "gromacs/simd/simd_math.h"
 #include "gromacs/utility/fatalerror.h"
 #include "gromacs/utility/arrayref.h"
 
 //! Scalar (non-SIMD) data types.
 struct ScalarDataTypes
 {
-    using RealType                     = real; //!< The data type to use as real.
-    using IntType                      = int;  //!< The data type to use as int.
-    static constexpr int simdRealWidth = 1;    //!< The width of the RealType.
-    static constexpr int simdIntWidth  = 1;    //!< The width of the IntType.
+    using RealType = real; //!< The data type to use as real.
+    using IntType  = int;  //!< The data type to use as int.
+    using BoolType = bool; //!< The data type to use as bool for real value comparison.
+    static constexpr int simdRealWidth = 1; //!< The width of the RealType.
+    static constexpr int simdIntWidth  = 1; //!< The width of the IntType.
 };
 
 #if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS
 //! SIMD data types.
 struct SimdDataTypes
 {
-    using RealType                     = gmx::SimdReal;         //!< The data type to use as real.
-    using IntType                      = gmx::SimdInt32;        //!< The data type to use as int.
-    static constexpr int simdRealWidth = GMX_SIMD_REAL_WIDTH;   //!< The width of the RealType.
-    static constexpr int simdIntWidth  = GMX_SIMD_FINT32_WIDTH; //!< The width of the IntType.
+    using RealType = gmx::SimdReal;  //!< The data type to use as real.
+    using IntType  = gmx::SimdInt32; //!< The data type to use as int.
+    using BoolType = gmx::SimdBool;  //!< The data type to use as bool for real value comparison.
+    static constexpr int simdRealWidth = GMX_SIMD_REAL_WIDTH; //!< The width of the RealType.
+#    if GMX_SIMD_HAVE_DOUBLE && GMX_DOUBLE
+    static constexpr int simdIntWidth = GMX_SIMD_DINT32_WIDTH; //!< The width of the IntType.
+#    else
+    static constexpr int simdIntWidth = GMX_SIMD_FINT32_WIDTH; //!< The width of the IntType.
+#    endif
 };
 #endif
 
+template<class RealType, class BoolType>
+static inline void
+pmeCoulombCorrectionVF(const RealType rSq, const real beta, RealType* pot, RealType* force, const BoolType mask)
+{ /* Evaluates gmx::pmeForceCorrection() and gmx::pmePotentialCorrection() at the masked
+   * value of rSq * beta^2 and stores the scaled results in *force and *pot.
+   * Both output pointers are always written; nothing is returned.
+   * NOTE(review): rSq is presumably the squared pair distance and beta the Ewald
+   * splitting coefficient - confirm against the call site.
+   * NOTE(review): assumes gmx::selectByMask() yields zero where mask is false - confirm.
+   */
+    const RealType brsq = gmx::selectByMask(rSq * beta * beta, mask); // (beta * r)^2 with the mask applied
+    *force              = -brsq * beta * gmx::pmeForceCorrection(brsq); // note the negative sign and the brsq * beta prefactor
+    *pot                = beta * gmx::pmePotentialCorrection(brsq); // the potential only carries a single factor of beta
+}
+
+template<class RealType, class BoolType>
+static inline void pmeLJCorrectionVF(const RealType rInv,
+                                     const RealType rSq,
+                                     const real     ewaldLJCoeffSq,
+                                     const real     ewaldLJCoeffSixDivSix,
+                                     RealType*      pot,
+                                     RealType*      force,
+                                     const BoolType mask,
+                                     const BoolType bIiEqJnr)
+{ /* Computes the LJ-PME correction terms from rInv, rSq and the two precomputed Ewald LJ
+   * coefficients, and stores them in *force and *pot. Both outputs are always written.
+   * mask is applied to rInv^2 and to rSq before they are used; bIiEqJnr is the selector
+   * handed to gmx::blend() for the potential.
+   * NOTE(review): assumes gmx::blend(a, b, sel) picks b where sel is true, so entries
+   * flagged by bIiEqJnr receive 0.5 * ewaldLJCoeffSixDivSix - confirm.
+   */
+    // Mask rInv^2 so that masked-out pair interactions get zero force and potential
+    const RealType rInvSq  = gmx::selectByMask(rInv * rInv, mask);
+    const RealType rInvSix = rInvSq * rInvSq * rInvSq;
+    // Mask rSq as well, to avoid underflow in exp()
+    const RealType coeffSqRSq       = ewaldLJCoeffSq * gmx::selectByMask(rSq, mask);
+    const RealType expNegCoeffSqRSq = gmx::exp(-coeffSqRSq);
+    const RealType poly             = 1.0_real + coeffSqRSq + 0.5_real * coeffSqRSq * coeffSqRSq; // 1 + x + x^2/2 with x = coeffSqRSq
+    *force = rInvSix - expNegCoeffSqRSq * (rInvSix * poly + ewaldLJCoeffSixDivSix);
+    *force = *force * rInvSq; // second step: scale the bracketed term by the masked rInv^2
+    // The self interaction is the limit for r -> 0, which we need to compute separately:
+    // it is supplied here as the constant 0.5 * ewaldLJCoeffSixDivSix and chosen via bIiEqJnr
+    *pot = gmx::blend(
+            rInvSix * (1.0_real - expNegCoeffSqRSq * poly), 0.5_real * ewaldLJCoeffSixDivSix, bIiEqJnr);
+}
+
 //! Computes r^(1/p) and 1/r^(1/p) for the standard p=6
-template<class RealType>
-static inline void pthRoot(const RealType r, RealType* pthRoot, RealType* invPthRoot)
+template<class RealType, class BoolType>
+static inline void pthRoot(const RealType r, RealType* pthRoot, RealType* invPthRoot, const BoolType mask)
 {
-    *invPthRoot = gmx::invsqrt(std::cbrt(r));
-    *pthRoot    = 1 / (*invPthRoot);
+    RealType cbrtRes = gmx::cbrt(r);
+    *invPthRoot      = gmx::maskzInvsqrt(cbrtRes, mask);
+    *pthRoot         = gmx::maskzInv(*invPthRoot, mask);
 }
 
 template<class RealType>
@@ -159,65 +203,56 @@ static inline RealType lennardJonesPotential(const RealType v6,
 }
 
 /* Ewald LJ */
-static inline real ewaldLennardJonesGridSubtract(const real c6grid, const real potentialShift, const real oneSixth)
+template<class RealType>
+static inline RealType ewaldLennardJonesGridSubtract(const RealType c6grid,
+                                                     const real     potentialShift,
+                                                     const real     oneSixth)
 {
     return (c6grid * potentialShift * oneSixth);
 }
 
 /* LJ Potential switch */
-template<class RealType>
+template<class RealType, class BoolType>
 static inline RealType potSwitchScalarForceMod(const RealType fScalarInp,
                                                const RealType potential,
                                                const RealType sw,
                                                const RealType r,
-                                               const RealType rVdw,
                                                const RealType dsw,
-                                               const real     zero)
+                                               const BoolType mask)
 {
-    if (r < rVdw)
-    {
-        real fScalar = fScalarInp * sw - r * potential * dsw;
-        return (fScalar);
-    }
-    return (zero);
+    /* The mask should select on rV < rVdw */
+    return (gmx::selectByMask(fScalarInp * sw - r * potential * dsw, mask));
 }
-template<class RealType>
-static inline RealType potSwitchPotentialMod(const RealType potentialInp,
-                                             const RealType sw,
-                                             const RealType r,
-                                             const RealType rVdw,
-                                             const real     zero)
+template<class RealType, class BoolType>
+static inline RealType potSwitchPotentialMod(const RealType potentialInp, const RealType sw, const BoolType mask)
 {
-    if (r < rVdw)
-    {
-        real potential = potentialInp * sw;
-        return (potential);
-    }
-    return (zero);
+    /* The mask should select on rV < rVdw */
+    return (gmx::selectByMask(potentialInp * sw, mask));
 }
 
 
 //! Templated free-energy non-bonded kernel
 template<typename DataTypes, bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
-static void nb_free_energy_kernel(const t_nblist&                nlist,
-                                  gmx::ArrayRef<const gmx::RVec> coords,
-                                  gmx::ForceWithShiftForces*     forceWithShiftForces,
-                                  const int                      ntype,
-                                  const real                     rlist,
-                                  const interaction_const_t&     ic,
-                                  gmx::ArrayRef<const gmx::RVec> shiftvec,
-                                  gmx::ArrayRef<const real>      nbfp,
-                                  gmx::ArrayRef<const real>      nbfp_grid,
-                                  gmx::ArrayRef<const real>      chargeA,
-                                  gmx::ArrayRef<const real>      chargeB,
-                                  gmx::ArrayRef<const int>       typeA,
-                                  gmx::ArrayRef<const int>       typeB,
-                                  int                            flags,
-                                  gmx::ArrayRef<const real>      lambda,
-                                  gmx::ArrayRef<real>            dvdl,
-                                  gmx::ArrayRef<real>            energygrp_elec,
-                                  gmx::ArrayRef<real>            energygrp_vdw,
-                                  t_nrnb* gmx_restrict           nrnb)
+static void nb_free_energy_kernel(const t_nblist&                           nlist,
+                                  const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                                  const int                                 ntype,
+                                  const real                                rlist,
+                                  const interaction_const_t&                ic,
+                                  gmx::ArrayRef<const gmx::RVec>            shiftvec,
+                                  gmx::ArrayRef<const real>                 nbfp,
+                                  gmx::ArrayRef<const real> gmx_unused      nbfp_grid,
+                                  gmx::ArrayRef<const real>                 chargeA,
+                                  gmx::ArrayRef<const real>                 chargeB,
+                                  gmx::ArrayRef<const int>                  typeA,
+                                  gmx::ArrayRef<const int>                  typeB,
+                                  int                                       flags,
+                                  gmx::ArrayRef<const real>                 lambda,
+                                  t_nrnb* gmx_restrict                      nrnb,
+                                  gmx::RVec*                                threadForceBuffer,
+                                  rvec*                                     threadForceShiftBuffer,
+                                  gmx::ArrayRef<real>                       threadVc,
+                                  gmx::ArrayRef<real>                       threadVv,
+                                  gmx::ArrayRef<real>                       threadDvdl)
 {
 #define STATE_A 0
 #define STATE_B 1
@@ -225,15 +260,15 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
 
     using RealType = typename DataTypes::RealType;
     using IntType  = typename DataTypes::IntType;
+    using BoolType = typename DataTypes::BoolType;
 
-    /* FIXME: How should these be handled with SIMD? */
-    constexpr real oneTwelfth = 1.0 / 12.0;
-    constexpr real oneSixth   = 1.0 / 6.0;
-    constexpr real zero       = 0.0;
-    constexpr real half       = 0.5;
-    constexpr real one        = 1.0;
-    constexpr real two        = 2.0;
-    constexpr real six        = 6.0;
+    constexpr real oneTwelfth = 1.0_real / 12.0_real;
+    constexpr real oneSixth   = 1.0_real / 6.0_real;
+    constexpr real zero       = 0.0_real;
+    constexpr real half       = 0.5_real;
+    constexpr real one        = 1.0_real;
+    constexpr real two        = 2.0_real;
+    constexpr real six        = 6.0_real;
 
     // Extract pair list data
     const int                nri    = nlist.nri;
@@ -243,47 +278,56 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
     gmx::ArrayRef<const int> shift  = nlist.shift;
     gmx::ArrayRef<const int> gid    = nlist.gid;
 
-    const real  lambda_coul   = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)];
-    const real  lambda_vdw    = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)];
-    const auto& scParams      = *ic.softCoreParameters;
-    const real  alpha_coul    = scParams.alphaCoulomb;
-    const real  alpha_vdw     = scParams.alphaVdw;
-    const real  lam_power     = scParams.lambdaPower;
-    const real  sigma6_def    = scParams.sigma6WithInvalidSigma;
-    const real  sigma6_min    = scParams.sigma6Minimum;
-    const bool  doForces      = ((flags & GMX_NONBONDED_DO_FORCE) != 0);
-    const bool  doShiftForces = ((flags & GMX_NONBONDED_DO_SHIFTFORCE) != 0);
-    const bool  doPotential   = ((flags & GMX_NONBONDED_DO_POTENTIAL) != 0);
+    const real  lambda_coul = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)];
+    const real  lambda_vdw  = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)];
+    const auto& scParams    = *ic.softCoreParameters;
+    const real gmx_unused alpha_coul    = scParams.alphaCoulomb;
+    const real gmx_unused alpha_vdw     = scParams.alphaVdw;
+    const real            lam_power     = scParams.lambdaPower;
+    const real gmx_unused sigma6_def    = scParams.sigma6WithInvalidSigma;
+    const real gmx_unused sigma6_min    = scParams.sigma6Minimum;
+    const bool            doForces      = ((flags & GMX_NONBONDED_DO_FORCE) != 0);
+    const bool            doShiftForces = ((flags & GMX_NONBONDED_DO_SHIFTFORCE) != 0);
+    const bool            doPotential   = ((flags & GMX_NONBONDED_DO_POTENTIAL) != 0);
 
     // Extract data from interaction_const_t
-    const real facel           = ic.epsfac;
-    const real rCoulomb        = ic.rcoulomb;
-    const real krf             = ic.reactionFieldCoefficient;
-    const real crf             = ic.reactionFieldShift;
-    const real shLjEwald       = ic.sh_lj_ewald;
-    const real rVdw            = ic.rvdw;
-    const real dispersionShift = ic.dispersion_shift.cpot;
-    const real repulsionShift  = ic.repulsion_shift.cpot;
+    const real            facel           = ic.epsfac;
+    const real            rCoulomb        = ic.rcoulomb;
+    const real            krf             = ic.reactionFieldCoefficient;
+    const real            crf             = ic.reactionFieldShift;
+    const real gmx_unused shLjEwald       = ic.sh_lj_ewald;
+    const real            rVdw            = ic.rvdw;
+    const real            dispersionShift = ic.dispersion_shift.cpot;
+    const real            repulsionShift  = ic.repulsion_shift.cpot;
+    const real            ewaldBeta       = ic.ewaldcoeff_q;
+    real gmx_unused       ewaldLJCoeffSq;
+    real gmx_unused       ewaldLJCoeffSixDivSix;
+    if constexpr (vdwInteractionTypeIsEwald)
+    {
+        ewaldLJCoeffSq        = ic.ewaldcoeff_lj * ic.ewaldcoeff_lj;
+        ewaldLJCoeffSixDivSix = ewaldLJCoeffSq * ewaldLJCoeffSq * ewaldLJCoeffSq / six;
+    }
 
     // Note that the nbnxm kernels do not support Coulomb potential switching at all
     GMX_ASSERT(ic.coulomb_modifier != InteractionModifiers::PotSwitch,
                "Potential switching is not supported for Coulomb with FEP");
 
-    real vdw_swV3, vdw_swV4, vdw_swV5, vdw_swF2, vdw_swF3, vdw_swF4;
-    if (vdwModifierIsPotSwitch)
+    const real      rVdwSwitch = ic.rvdw_switch;
+    real gmx_unused vdw_swV3, vdw_swV4, vdw_swV5, vdw_swF2, vdw_swF3, vdw_swF4;
+    if constexpr (vdwModifierIsPotSwitch)
     {
-        const real d = ic.rvdw - ic.rvdw_switch;
-        vdw_swV3     = -10.0 / (d * d * d);
-        vdw_swV4     = 15.0 / (d * d * d * d);
-        vdw_swV5     = -6.0 / (d * d * d * d * d);
-        vdw_swF2     = -30.0 / (d * d * d);
-        vdw_swF3     = 60.0 / (d * d * d * d);
-        vdw_swF4     = -30.0 / (d * d * d * d * d);
+        const real d = rVdw - rVdwSwitch;
+        vdw_swV3     = -10.0_real / (d * d * d);
+        vdw_swV4     = 15.0_real / (d * d * d * d);
+        vdw_swV5     = -6.0_real / (d * d * d * d * d);
+        vdw_swF2     = -30.0_real / (d * d * d);
+        vdw_swF3     = 60.0_real / (d * d * d * d);
+        vdw_swF4     = -30.0_real / (d * d * d * d * d);
     }
     else
     {
         /* Avoid warnings from stupid compilers (looking at you, Clang!) */
-        vdw_swV3 = vdw_swV4 = vdw_swV5 = vdw_swF2 = vdw_swF3 = vdw_swF4 = 0.0;
+        vdw_swV3 = vdw_swV4 = vdw_swV5 = vdw_swF2 = vdw_swF3 = vdw_swF4 = zero;
     }
 
     NbkernelElecType icoul;
@@ -299,33 +343,11 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
     real rcutoff_max2 = std::max(ic.rcoulomb, ic.rvdw);
     rcutoff_max2      = rcutoff_max2 * rcutoff_max2;
 
-    const real* tab_ewald_F_lj           = nullptr;
-    const real* tab_ewald_V_lj           = nullptr;
-    const real* ewtab                    = nullptr;
-    real        coulombTableScale        = 0;
-    real        coulombTableScaleInvHalf = 0;
-    real        vdwTableScale            = 0;
-    real        vdwTableScaleInvHalf     = 0;
-    real        sh_ewald                 = 0;
-    if (elecInteractionTypeIsEwald || vdwInteractionTypeIsEwald)
+    real gmx_unused sh_ewald = 0;
+    if constexpr (elecInteractionTypeIsEwald || vdwInteractionTypeIsEwald)
     {
         sh_ewald = ic.sh_ewald;
     }
-    if (elecInteractionTypeIsEwald)
-    {
-        const auto& coulombTables = *ic.coulombEwaldTables;
-        ewtab                     = coulombTables.tableFDV0.data();
-        coulombTableScale         = coulombTables.scale;
-        coulombTableScaleInvHalf  = half / coulombTableScale;
-    }
-    if (vdwInteractionTypeIsEwald)
-    {
-        const auto& vdwTables = *ic.vdwEwaldTables;
-        tab_ewald_F_lj        = vdwTables.tableF.data();
-        tab_ewald_V_lj        = vdwTables.tableV.data();
-        vdwTableScale         = vdwTables.scale;
-        vdwTableScaleInvHalf  = half / vdwTableScale;
-    }
 
     /* For Ewald/PME interactions we cannot easily apply the soft-core component to
      * reciprocal space. When we use non-switched Ewald interactions, we
@@ -342,8 +364,8 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
     GMX_RELEASE_ASSERT(!(vdwInteractionTypeIsEwald && vdwModifierIsPotSwitch),
                        "Can not apply soft-core to switched Ewald potentials");
 
-    real dvdlCoul = 0;
-    real dvdlVdw  = 0;
+    RealType dvdlCoul(zero);
+    RealType dvdlVdw(zero);
 
     /* Lambda factor for state A, 1-lambda*/
     real LFC[NSTATES], LFV[NSTATES];
@@ -356,11 +378,11 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
 
     /*derivative of the lambda factor for state A and B */
     real DLF[NSTATES];
-    DLF[STATE_A] = -1;
-    DLF[STATE_B] = 1;
+    DLF[STATE_A] = -one;
+    DLF[STATE_B] = one;
 
-    real           lFacCoul[NSTATES], dlFacCoul[NSTATES], lFacVdw[NSTATES], dlFacVdw[NSTATES];
-    constexpr real sc_r_power = 6.0_real;
+    real gmx_unused lFacCoul[NSTATES], dlFacCoul[NSTATES], lFacVdw[NSTATES], dlFacVdw[NSTATES];
+    constexpr real  sc_r_power = six;
     for (int i = 0; i < NSTATES; i++)
     {
         lFacCoul[i]  = (lam_power == 2 ? (1 - LFC[i]) * (1 - LFC[i]) : (1 - LFC[i]));
@@ -370,57 +392,172 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
     }
 
     // TODO: We should get rid of using pointers to real
-    const real*        x      = coords[0];
-    real* gmx_restrict f      = &(forceWithShiftForces->force()[0][0]);
-    real* gmx_restrict fshift = &(forceWithShiftForces->shiftForces()[0][0]);
+    const real* gmx_restrict x = coords.paddedConstArrayRef().data()[0];
 
     const real rlistSquared = gmx::square(rlist);
 
-    int numExcludedPairsBeyondRlist = 0;
+    bool haveExcludedPairsBeyondRlist = false;
 
     for (int n = 0; n < nri; n++)
     {
-        int npair_within_cutoff = 0;
-
-        const int  is    = shift[n];
-        const int  is3   = DIM * is;
-        const real shX   = shiftvec[is][XX];
-        const real shY   = shiftvec[is][YY];
-        const real shZ   = shiftvec[is][ZZ];
-        const int  nj0   = jindex[n];
-        const int  nj1   = jindex[n + 1];
-        const int  ii    = iinr[n];
-        const int  ii3   = 3 * ii;
-        const real ix    = shX + x[ii3 + 0];
-        const real iy    = shY + x[ii3 + 1];
-        const real iz    = shZ + x[ii3 + 2];
-        const real iqA   = facel * chargeA[ii];
-        const real iqB   = facel * chargeB[ii];
-        const int  ntiA  = 2 * ntype * typeA[ii];
-        const int  ntiB  = 2 * ntype * typeB[ii];
-        real       vCTot = 0;
-        real       vVTot = 0;
-        real       fIX   = 0;
-        real       fIY   = 0;
-        real       fIZ   = 0;
-
-        for (int k = nj0; k < nj1; k++)
+        bool havePairsWithinCutoff = false;
+
+        const int  is   = shift[n];
+        const real shX  = shiftvec[is][XX];
+        const real shY  = shiftvec[is][YY];
+        const real shZ  = shiftvec[is][ZZ];
+        const int  nj0  = jindex[n];
+        const int  nj1  = jindex[n + 1];
+        const int  ii   = iinr[n];
+        const int  ii3  = 3 * ii;
+        const real ix   = shX + x[ii3 + 0];
+        const real iy   = shY + x[ii3 + 1];
+        const real iz   = shZ + x[ii3 + 2];
+        const real iqA  = facel * chargeA[ii];
+        const real iqB  = facel * chargeB[ii];
+        const int  ntiA = ntype * typeA[ii];
+        const int  ntiB = ntype * typeB[ii];
+        RealType   vCTot(0);
+        RealType   vVTot(0);
+        RealType   fIX(0);
+        RealType   fIY(0);
+        RealType   fIZ(0);
+
+#if GMX_SIMD_HAVE_REAL
+        alignas(GMX_SIMD_ALIGNMENT) int preloadIi[DataTypes::simdRealWidth];
+        alignas(GMX_SIMD_ALIGNMENT) int preloadIs[DataTypes::simdRealWidth];
+#else
+        int preloadIi[DataTypes::simdRealWidth];
+        int preloadIs[DataTypes::simdRealWidth];
+#endif
+        for (int s = 0; s < DataTypes::simdRealWidth; s++)
         {
-            int            tj[NSTATES];
-            const int      jnr = jjnr[k];
-            const int      j3  = 3 * jnr;
-            RealType       c6[NSTATES], c12[NSTATES], qq[NSTATES], vCoul[NSTATES], vVdw[NSTATES];
-            RealType       r, rInv, rp, rpm2;
-            RealType       alphaVdwEff, alphaCoulEff, sigma6[NSTATES];
-            const RealType dX  = ix - x[j3];
-            const RealType dY  = iy - x[j3 + 1];
-            const RealType dZ  = iz - x[j3 + 2];
+            preloadIi[s] = ii;
+            preloadIs[s] = shift[n];
+        }
+        IntType ii_s = gmx::load<IntType>(preloadIi);
+
+        for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
+        {
+            RealType r, rInv;
+
+#if GMX_SIMD_HAVE_REAL
+            alignas(GMX_SIMD_ALIGNMENT) real    preloadPairIsValid[DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real    preloadPairIncluded[DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) int32_t preloadJnr[DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) int32_t typeIndices[NSTATES][DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real    preloadQq[NSTATES][DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadSigma6[NSTATES][DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadAlphaVdwEff[DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadAlphaCoulEff[DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real preloadLjPmeC6Grid[NSTATES][DataTypes::simdRealWidth];
+#else
+            real            preloadPairIsValid[DataTypes::simdRealWidth];
+            real            preloadPairIncluded[DataTypes::simdRealWidth];
+            int             preloadJnr[DataTypes::simdRealWidth];
+            int             typeIndices[NSTATES][DataTypes::simdRealWidth];
+            real            preloadQq[NSTATES][DataTypes::simdRealWidth];
+            real gmx_unused preloadSigma6[NSTATES][DataTypes::simdRealWidth];
+            real gmx_unused preloadAlphaVdwEff[DataTypes::simdRealWidth];
+            real gmx_unused preloadAlphaCoulEff[DataTypes::simdRealWidth];
+            real            preloadLjPmeC6Grid[NSTATES][DataTypes::simdRealWidth];
+#endif
+            for (int s = 0; s < DataTypes::simdRealWidth; s++)
+            {
+                if (k + s < nj1)
+                {
+                    preloadPairIsValid[s] = true;
+                    /* Check if this pair is on the exclusions list. */
+                    preloadPairIncluded[s]  = (nlist.excl_fep.empty() || nlist.excl_fep[k + s]);
+                    const int jnr           = jjnr[k + s];
+                    preloadJnr[s]           = jnr;
+                    typeIndices[STATE_A][s] = ntiA + typeA[jnr];
+                    typeIndices[STATE_B][s] = ntiB + typeB[jnr];
+                    preloadQq[STATE_A][s]   = iqA * chargeA[jnr];
+                    preloadQq[STATE_B][s]   = iqB * chargeB[jnr];
+
+                    for (int i = 0; i < NSTATES; i++)
+                    {
+                        if constexpr (vdwInteractionTypeIsEwald)
+                        {
+                            preloadLjPmeC6Grid[i][s] = nbfp_grid[2 * typeIndices[i][s]];
+                        }
+                        else
+                        {
+                            preloadLjPmeC6Grid[i][s] = 0;
+                        }
+                        if constexpr (useSoftCore)
+                        {
+                            const real c6  = nbfp[2 * typeIndices[i][s]];
+                            const real c12 = nbfp[2 * typeIndices[i][s] + 1];
+                            if (c6 > 0 && c12 > 0)
+                            {
+                                /* c12 is stored scaled with 12.0 and c6 is scaled with 6.0 - correct for this */
+                                preloadSigma6[i][s] = 0.5_real * c12 / c6;
+                                if (preloadSigma6[i][s]
+                                    < sigma6_min) /* for disappearing coul and vdw with soft core at the same time */
+                                {
+                                    preloadSigma6[i][s] = sigma6_min;
+                                }
+                            }
+                            else
+                            {
+                                preloadSigma6[i][s] = sigma6_def;
+                            }
+                        }
+                    }
+                    if constexpr (useSoftCore)
+                    {
+                        /* only use softcore if one of the states has a zero endstate - softcore is for avoiding infinities!*/
+                        const real c12A = nbfp[2 * typeIndices[STATE_A][s] + 1];
+                        const real c12B = nbfp[2 * typeIndices[STATE_B][s] + 1];
+                        if (c12A > 0 && c12B > 0)
+                        {
+                            preloadAlphaVdwEff[s]  = 0;
+                            preloadAlphaCoulEff[s] = 0;
+                        }
+                        else
+                        {
+                            preloadAlphaVdwEff[s]  = alpha_vdw;
+                            preloadAlphaCoulEff[s] = alpha_coul;
+                        }
+                    }
+                }
+                else
+                {
+                    preloadJnr[s]          = jjnr[k];
+                    preloadPairIsValid[s]  = false;
+                    preloadPairIncluded[s] = false;
+                    preloadAlphaVdwEff[s]  = 0;
+                    preloadAlphaCoulEff[s] = 0;
+
+                    for (int i = 0; i < NSTATES; i++)
+                    {
+                        typeIndices[STATE_A][s]  = ntiA + typeA[jjnr[k]];
+                        typeIndices[STATE_B][s]  = ntiB + typeB[jjnr[k]];
+                        preloadLjPmeC6Grid[i][s] = 0;
+                        preloadQq[i][s]          = 0;
+                        preloadSigma6[i][s]      = 0;
+                    }
+                }
+            }
+
+            RealType jx, jy, jz;
+            gmx::gatherLoadUTranspose<3>(reinterpret_cast<const real*>(x), preloadJnr, &jx, &jy, &jz);
+
+            const RealType pairIsValid   = gmx::load<RealType>(preloadPairIsValid);
+            const RealType pairIncluded  = gmx::load<RealType>(preloadPairIncluded);
+            const BoolType bPairIncluded = (pairIncluded != zero);
+            const BoolType bPairExcluded = (pairIncluded == zero && pairIsValid != zero);
+
+            const RealType dX  = ix - jx;
+            const RealType dY  = iy - jy;
+            const RealType dZ  = iz - jz;
             const RealType rSq = dX * dX + dY * dY + dZ * dZ;
-            RealType       fScalC[NSTATES], fScalV[NSTATES];
-            /* Check if this pair on the exlusions list.*/
-            const bool bPairIncluded = nlist.excl_fep.empty() || nlist.excl_fep[k];
 
-            if (rSq >= rcutoff_max2 && bPairIncluded)
+            BoolType withinCutoffMask = (rSq < rcutoff_max2);
+
+            if (!gmx::anyTrue(withinCutoffMask || bPairExcluded))
             {
                 /* We save significant time by skipping all code below.
                  * Note that with soft-core interactions, the actual cut-off
@@ -430,38 +567,55 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                  * when using Ewald: the reciprocal-space
                  * Ewald component still needs to be subtracted.
                  */
-
                 continue;
             }
-            npair_within_cutoff++;
-
-            if (rSq > rlistSquared)
+            else
             {
-                numExcludedPairsBeyondRlist++;
+                havePairsWithinCutoff = true;
             }
 
-            if (rSq > 0)
+            if (gmx::anyTrue(rlistSquared < rSq && bPairExcluded))
             {
-                /* Note that unlike in the nbnxn kernels, we do not need
-                 * to clamp the value of rSq before taking the invsqrt
-                 * to avoid NaN in the LJ calculation, since here we do
-                 * not calculate LJ interactions when C6 and C12 are zero.
-                 */
+                haveExcludedPairsBeyondRlist = true;
+            }
 
-                rInv = gmx::invsqrt(rSq);
-                r    = rSq * rInv;
+            const IntType  jnr_s    = gmx::load<IntType>(preloadJnr);
+            const BoolType bIiEqJnr = gmx::cvtIB2B(ii_s == jnr_s);
+
+            RealType            c6[NSTATES];
+            RealType            c12[NSTATES];
+            RealType gmx_unused sigma6[NSTATES];
+            RealType            qq[NSTATES];
+            RealType gmx_unused ljPmeC6Grid[NSTATES];
+            RealType gmx_unused alphaVdwEff;
+            RealType gmx_unused alphaCoulEff;
+            for (int i = 0; i < NSTATES; i++)
+            {
+                gmx::gatherLoadTranspose<2>(nbfp.data(), typeIndices[i], &c6[i], &c12[i]);
+                qq[i]          = gmx::load<RealType>(preloadQq[i]);
+                ljPmeC6Grid[i] = gmx::load<RealType>(preloadLjPmeC6Grid[i]);
+                if constexpr (useSoftCore)
+                {
+                    sigma6[i] = gmx::load<RealType>(preloadSigma6[i]);
+                }
             }
-            else
+            if constexpr (useSoftCore)
             {
-                /* The force at r=0 is zero, because of symmetry.
-                 * But note that the potential is in general non-zero,
-                 * since the soft-cored r will be non-zero.
-                 */
-                rInv = 0;
-                r    = 0;
+                alphaVdwEff  = gmx::load<RealType>(preloadAlphaVdwEff);
+                alphaCoulEff = gmx::load<RealType>(preloadAlphaCoulEff);
             }
 
-            if (useSoftCore)
+            BoolType rSqValid = (zero < rSq);
+
+            /* The force at r=0 is zero, because of symmetry.
+             * But note that the potential is in general non-zero,
+             * since the soft-cored r will be non-zero.
+             */
+            rInv = gmx::maskzInvsqrt(rSq, rSqValid);
+            r    = rSq * rInv;
+
+            RealType gmx_unused rp, rpm2;
+            if constexpr (useSoftCore)
             {
                 rpm2 = rSq * rSq;  /* r4 */
                 rp   = rpm2 * rSq; /* r6 */
@@ -473,79 +627,45 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                  * the simplest math and cheapest code.
                  */
                 rpm2 = rInv * rInv;
-                rp   = 1;
+                rp   = one;
             }
 
-            RealType fScal = 0;
+            RealType fScal(0);
 
-            qq[STATE_A] = iqA * chargeA[jnr];
-            qq[STATE_B] = iqB * chargeB[jnr];
-
-            tj[STATE_A] = ntiA + 2 * typeA[jnr];
-            tj[STATE_B] = ntiB + 2 * typeB[jnr];
-
-            if (bPairIncluded)
+            /* The following block is masked to only calculate values having bPairIncluded. If
+             * bPairIncluded is true then withinCutoffMask must also be true. */
+            if (gmx::anyTrue(withinCutoffMask && bPairIncluded))
             {
-                c6[STATE_A] = nbfp[tj[STATE_A]];
-                c6[STATE_B] = nbfp[tj[STATE_B]];
-
+                RealType fScalC[NSTATES], fScalV[NSTATES];
+                RealType vCoul[NSTATES], vVdw[NSTATES];
                 for (int i = 0; i < NSTATES; i++)
                 {
-                    c12[i] = nbfp[tj[i] + 1];
-                    if (useSoftCore)
+                    fScalC[i] = zero;
+                    fScalV[i] = zero;
+                    vCoul[i]  = zero;
+                    vVdw[i]   = zero;
+
+                    RealType gmx_unused rInvC, rInvV, rC, rV, rPInvC, rPInvV;
+
+                    /* The following block is masked to require (qq[i] != 0 || c6[i] != 0 || c12[i]
+                     * != 0) in addition to bPairIncluded, which in turn requires withinCutoffMask. */
+                    BoolType nonZeroState = ((qq[i] != zero || c6[i] != zero || c12[i] != zero)
+                                             && bPairIncluded && withinCutoffMask);
+                    if (gmx::anyTrue(nonZeroState))
                     {
-                        if ((c6[i] > 0) && (c12[i] > 0))
-                        {
-                            /* c12 is stored scaled with 12.0 and c6 is scaled with 6.0 - correct for this */
-                            sigma6[i] = half * c12[i] / c6[i];
-                            if (sigma6[i] < sigma6_min) /* for disappearing coul and vdw with soft core at the same time */
-                            {
-                                sigma6[i] = sigma6_min;
-                            }
-                        }
-                        else
+                        if constexpr (useSoftCore)
                         {
-                            sigma6[i] = sigma6_def;
-                        }
-                    }
-                }
-
-                if (useSoftCore)
-                {
-                    /* only use softcore if one of the states has a zero endstate - softcore is for avoiding infinities!*/
-                    if ((c12[STATE_A] > 0) && (c12[STATE_B] > 0))
-                    {
-                        alphaVdwEff  = 0;
-                        alphaCoulEff = 0;
-                    }
-                    else
-                    {
-                        alphaVdwEff  = alpha_vdw;
-                        alphaCoulEff = alpha_coul;
-                    }
-                }
-
-                for (int i = 0; i < NSTATES; i++)
-                {
-                    fScalC[i] = 0;
-                    fScalV[i] = 0;
-                    vCoul[i]  = 0;
-                    vVdw[i]   = 0;
-
-                    RealType rInvC, rInvV, rC, rV, rPInvC, rPInvV;
+                            RealType divisor      = (alphaCoulEff * lFacCoul[i] * sigma6[i] + rp);
+                            BoolType validDivisor = (zero < divisor);
+                            rPInvC                = gmx::maskzInv(divisor, validDivisor);
+                            pthRoot(rPInvC, &rInvC, &rC, validDivisor);
 
-                    /* Only spend time on A or B state if it is non-zero */
-                    if ((qq[i] != 0) || (c6[i] != 0) || (c12[i] != 0))
-                    {
-                        /* this section has to be inside the loop because of the dependence on sigma6 */
-                        if (useSoftCore)
-                        {
-                            rPInvC = one / (alphaCoulEff * lFacCoul[i] * sigma6[i] + rp);
-                            pthRoot(rPInvC, &rInvC, &rC);
-                            if (scLambdasOrAlphasDiffer)
+                            if constexpr (scLambdasOrAlphasDiffer)
                             {
-                                rPInvV = one / (alphaVdwEff * lFacVdw[i] * sigma6[i] + rp);
-                                pthRoot(rPInvV, &rInvV, &rV);
+                                RealType divisor      = (alphaVdwEff * lFacVdw[i] * sigma6[i] + rp);
+                                BoolType validDivisor = (zero < divisor);
+                                rPInvV                = gmx::maskzInv(divisor, validDivisor);
+                                pthRoot(rPInvV, &rInvV, &rV, validDivisor);
                             }
                             else
                             {
@@ -557,25 +677,31 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                         }
                         else
                         {
-                            rPInvC = 1;
+                            rPInvC = one;
                             rInvC  = rInv;
                             rC     = r;
 
-                            rPInvV = 1;
+                            rPInvV = one;
                             rInvV  = rInv;
                             rV     = r;
                         }
 
-                        /* Only process the coulomb interactions if we have charges,
-                         * and if we either include all entries in the list (no cutoff
+                        /* Only process the coulomb interactions if we either
+                         * include all entries in the list (no cutoff
                          * used in the kernel), or if we are within the cutoff.
                          */
-                        bool computeElecInteraction = (elecInteractionTypeIsEwald && r < rCoulomb)
-                                                      || (!elecInteractionTypeIsEwald && rC < rCoulomb);
-
-                        if ((qq[i] != 0) && computeElecInteraction)
+                        BoolType computeElecInteraction;
+                        if constexpr (elecInteractionTypeIsEwald)
                         {
-                            if (elecInteractionTypeIsEwald)
+                            computeElecInteraction = (r < rCoulomb && qq[i] != zero && bPairIncluded);
+                        }
+                        else
+                        {
+                            computeElecInteraction = (rC < rCoulomb && qq[i] != zero && bPairIncluded);
+                        }
+                        if (gmx::anyTrue(computeElecInteraction))
+                        {
+                            if constexpr (elecInteractionTypeIsEwald)
                             {
                                 vCoul[i]  = ewaldPotential(qq[i], rInvC, sh_ewald);
                                 fScalC[i] = ewaldScalarForce(qq[i], rInvC);
@@ -585,19 +711,30 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                                 vCoul[i]  = reactionFieldPotential(qq[i], rInvC, rC, krf, crf);
                                 fScalC[i] = reactionFieldScalarForce(qq[i], rInvC, rC, krf, two);
                             }
+
+                            vCoul[i]  = gmx::selectByMask(vCoul[i], computeElecInteraction);
+                            fScalC[i] = gmx::selectByMask(fScalC[i], computeElecInteraction);
                         }
 
-                        /* Only process the VDW interactions if we have
-                         * some non-zero parameters, and if we either
+                        /* Only process the VDW interactions if we either
                          * include all entries in the list (no cutoff used
                          * in the kernel), or if we are within the cutoff.
                          */
-                        bool computeVdwInteraction = (vdwInteractionTypeIsEwald && r < rVdw)
-                                                     || (!vdwInteractionTypeIsEwald && rV < rVdw);
-                        if ((c6[i] != 0 || c12[i] != 0) && computeVdwInteraction)
+                        BoolType computeVdwInteraction;
+                        if constexpr (vdwInteractionTypeIsEwald)
+                        {
+                            computeVdwInteraction =
+                                    (r < rVdw && (c6[i] != 0 || c12[i] != 0) && bPairIncluded);
+                        }
+                        else
+                        {
+                            computeVdwInteraction =
+                                    (rV < rVdw && (c6[i] != 0 || c12[i] != 0) && bPairIncluded);
+                        }
+                        if (gmx::anyTrue(computeVdwInteraction))
                         {
                             RealType rInv6;
-                            if (useSoftCore)
+                            if constexpr (useSoftCore)
                             {
                                 rInv6 = rPInvV;
                             }
@@ -612,26 +749,33 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                                     vVdw6, vVdw12, c6[i], c12[i], repulsionShift, dispersionShift, oneSixth, oneTwelfth);
                             fScalV[i] = lennardJonesScalarForce(vVdw6, vVdw12);
 
-                            if (vdwInteractionTypeIsEwald)
+                            if constexpr (vdwInteractionTypeIsEwald)
                             {
                                 /* Subtract the grid potential at the cut-off */
-                                vVdw[i] += ewaldLennardJonesGridSubtract(
-                                        nbfp_grid[tj[i]], shLjEwald, oneSixth);
+                                vVdw[i] = vVdw[i]
+                                          + gmx::selectByMask(ewaldLennardJonesGridSubtract(
+                                                                      ljPmeC6Grid[i], shLjEwald, oneSixth),
+                                                              computeVdwInteraction);
                             }
 
-                            if (vdwModifierIsPotSwitch)
+                            if constexpr (vdwModifierIsPotSwitch)
                             {
-                                RealType d        = rV - ic.rvdw_switch;
-                                d                 = (d > zero) ? d : zero;
-                                const RealType d2 = d * d;
+                                RealType d             = rV - rVdwSwitch;
+                                BoolType zeroMask      = zero < d;
+                                BoolType potSwitchMask = rV < rVdw;
+                                d                      = gmx::selectByMask(d, zeroMask);
+                                const RealType d2      = d * d;
                                 const RealType sw =
                                         one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5));
                                 const RealType dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4));
 
                                 fScalV[i] = potSwitchScalarForceMod(
-                                        fScalV[i], vVdw[i], sw, rV, rVdw, dsw, zero);
-                                vVdw[i] = potSwitchPotentialMod(vVdw[i], sw, rV, rVdw, zero);
+                                        fScalV[i], vVdw[i], sw, rV, dsw, potSwitchMask);
+                                vVdw[i] = potSwitchPotentialMod(vVdw[i], sw, potSwitchMask);
                             }
+
+                            vVdw[i]   = gmx::selectByMask(vVdw[i], computeVdwInteraction);
+                            fScalV[i] = gmx::selectByMask(fScalV[i], computeVdwInteraction);
                         }
 
                         /* fScalC (and fScalV) now contain: dV/drC * rC
@@ -639,57 +783,69 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                          * Further down we first multiply by r^p-2 and then by
                          * the vector r, which in total gives: dV/drC * (r/rC)^1-p
                          */
-                        fScalC[i] *= rPInvC;
-                        fScalV[i] *= rPInvV;
-                    }
-                } // end for (int i = 0; i < NSTATES; i++)
-
-                /* Assemble A and B states */
-                for (int i = 0; i < NSTATES; i++)
+                        fScalC[i] = fScalC[i] * rPInvC;
+                        fScalV[i] = fScalV[i] * rPInvV;
+                    } // end of block requiring nonZeroState
+                }     // end for (int i = 0; i < NSTATES; i++)
+
+                /* Assemble A and B states. */
+                BoolType assembleStates = (bPairIncluded && withinCutoffMask);
+                if (gmx::anyTrue(assembleStates))
                 {
-                    vCTot += LFC[i] * vCoul[i];
-                    vVTot += LFV[i] * vVdw[i];
+                    for (int i = 0; i < NSTATES; i++)
+                    {
+                        vCTot = vCTot + LFC[i] * vCoul[i];
+                        vVTot = vVTot + LFV[i] * vVdw[i];
 
-                    fScal += LFC[i] * fScalC[i] * rpm2;
-                    fScal += LFV[i] * fScalV[i] * rpm2;
+                        fScal = fScal + LFC[i] * fScalC[i] * rpm2;
+                        fScal = fScal + LFV[i] * fScalV[i] * rpm2;
 
-                    if (useSoftCore)
-                    {
-                        dvdlCoul += vCoul[i] * DLF[i]
-                                    + LFC[i] * alphaCoulEff * dlFacCoul[i] * fScalC[i] * sigma6[i];
-                        dvdlVdw += vVdw[i] * DLF[i]
-                                   + LFV[i] * alphaVdwEff * dlFacVdw[i] * fScalV[i] * sigma6[i];
-                    }
-                    else
-                    {
-                        dvdlCoul += vCoul[i] * DLF[i];
-                        dvdlVdw += vVdw[i] * DLF[i];
+                        if constexpr (useSoftCore)
+                        {
+                            dvdlCoul = dvdlCoul + vCoul[i] * DLF[i]
+                                       + LFC[i] * alphaCoulEff * dlFacCoul[i] * fScalC[i] * sigma6[i];
+                            dvdlVdw = dvdlVdw + vVdw[i] * DLF[i]
+                                      + LFV[i] * alphaVdwEff * dlFacVdw[i] * fScalV[i] * sigma6[i];
+                        }
+                        else
+                        {
+                            dvdlCoul = dvdlCoul + vCoul[i] * DLF[i];
+                            dvdlVdw  = dvdlVdw + vVdw[i] * DLF[i];
+                        }
                     }
                 }
-            } // end if (bPairIncluded)
-            else if (icoul == NbkernelElecType::ReactionField)
+            } // end of block requiring bPairIncluded && withinCutoffMask
+            /* In the following block bPairIncluded should be false in the masks. */
+            if (icoul == NbkernelElecType::ReactionField)
             {
-                /* For excluded pairs, which are only in this pair list when
-                 * using the Verlet scheme, we don't use soft-core.
-                 * As there is no singularity, there is no need for soft-core.
-                 */
-                const real FF = -two * krf;
-                RealType   VV = krf * rSq - crf;
+                const BoolType computeReactionField = bPairExcluded;
 
-                if (ii == jnr)
+                if (gmx::anyTrue(computeReactionField))
                 {
-                    VV *= half;
-                }
+                    /* For excluded pairs we don't use soft-core.
+                     * As there is no singularity, there is no need for soft-core.
+                     */
+                    const RealType FF = -two * krf;
+                    RealType       VV = krf * rSq - crf;
 
-                for (int i = 0; i < NSTATES; i++)
-                {
-                    vCTot += LFC[i] * qq[i] * VV;
-                    fScal += LFC[i] * qq[i] * FF;
-                    dvdlCoul += DLF[i] * qq[i] * VV;
+                    /* If ii == jnr the i particle (ii) has itself (jnr)
+                     * in its neighborlist. This corresponds to a self-interaction
+                     * that will occur twice. Scale it down by 50% to only include
+                     * it once.
+                     */
+                    VV = VV * gmx::blend(one, half, bIiEqJnr);
+
+                    for (int i = 0; i < NSTATES; i++)
+                    {
+                        vCTot = vCTot + gmx::selectByMask(LFC[i] * qq[i] * VV, computeReactionField);
+                        fScal = fScal + gmx::selectByMask(LFC[i] * qq[i] * FF, computeReactionField);
+                        dvdlCoul = dvdlCoul + gmx::selectByMask(DLF[i] * qq[i] * VV, computeReactionField);
+                    }
                 }
             }
 
-            if (elecInteractionTypeIsEwald && (r < rCoulomb || !bPairIncluded))
+            const BoolType computeElecEwaldInteraction = (bPairExcluded || r < rCoulomb);
+            if (elecInteractionTypeIsEwald && gmx::anyTrue(computeElecEwaldInteraction))
             {
                 /* See comment in the preamble. When using Ewald interactions
                  * (unless we use a switch modifier) we subtract the reciprocal-space
@@ -699,39 +855,33 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                  * the softcore to the entire electrostatic interaction,
                  * including the reciprocal-space component.
                  */
-                real v_lr, f_lr;
+                RealType v_lr, f_lr;
 
-                const RealType ewrt   = r * coulombTableScale;
-                IntType        ewitab = static_cast<IntType>(ewrt);
-                const RealType eweps  = ewrt - ewitab;
-                ewitab                = 4 * ewitab;
-                f_lr                  = ewtab[ewitab] + eweps * ewtab[ewitab + 1];
-                v_lr = (ewtab[ewitab + 2] - coulombTableScaleInvHalf * eweps * (ewtab[ewitab] + f_lr));
-                f_lr *= rInv;
+                pmeCoulombCorrectionVF(rSq, ewaldBeta, &v_lr, &f_lr, rSqValid);
+                f_lr = f_lr * rInv * rInv;
 
                 /* Note that any possible Ewald shift has already been applied in
                  * the normal interaction part above.
                  */
 
-                if (ii == jnr)
-                {
-                    /* If we get here, the i particle (ii) has itself (jnr)
-                     * in its neighborlist. This can only happen with the Verlet
-                     * scheme, and corresponds to a self-interaction that will
-                     * occur twice. Scale it down by 50% to only include it once.
-                     */
-                    v_lr *= half;
-                }
+                /* If ii == jnr the i particle (ii) has itself (jnr)
+                 * in its neighborlist. This corresponds to a self-interaction
+                 * that will occur twice. Scale it down by 50% to only include
+                 * it once.
+                 */
+                v_lr = v_lr * gmx::blend(one, half, bIiEqJnr);
 
                 for (int i = 0; i < NSTATES; i++)
                 {
-                    vCTot -= LFC[i] * qq[i] * v_lr;
-                    fScal -= LFC[i] * qq[i] * f_lr;
-                    dvdlCoul -= (DLF[i] * qq[i]) * v_lr;
+                    vCTot = vCTot - gmx::selectByMask(LFC[i] * qq[i] * v_lr, computeElecEwaldInteraction);
+                    fScal = fScal - gmx::selectByMask(LFC[i] * qq[i] * f_lr, computeElecEwaldInteraction);
+                    dvdlCoul = dvdlCoul
+                               - gmx::selectByMask(DLF[i] * qq[i] * v_lr, computeElecEwaldInteraction);
                 }
             }
 
-            if (vdwInteractionTypeIsEwald && (r < rVdw || !bPairIncluded))
+            const BoolType computeVdwEwaldInteraction = (bPairExcluded || r < rVdw);
+            if (vdwInteractionTypeIsEwald && gmx::anyTrue(computeVdwEwaldInteraction))
             {
                 /* See comment in the preamble. When using LJ-Ewald interactions
                  * (unless we use a switch modifier) we subtract the reciprocal-space
@@ -741,147 +891,105 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                  * the softcore to the entire VdW interaction,
                  * including the reciprocal-space component.
                  */
-                /* We could also use the analytical form here
-                 * iso a table, but that can cause issues for
-                 * r close to 0 for non-interacting pairs.
-                 */
 
-                const RealType rs   = rSq * rInv * vdwTableScale;
-                const IntType  ri   = static_cast<IntType>(rs);
-                const RealType frac = rs - ri;
-                const RealType f_lr = (1 - frac) * tab_ewald_F_lj[ri] + frac * tab_ewald_F_lj[ri + 1];
-                /* TODO: Currently the Ewald LJ table does not contain
-                 * the factor 1/6, we should add this.
-                 */
-                const RealType FF = f_lr * rInv / six;
-                RealType       VV =
-                        (tab_ewald_V_lj[ri] - vdwTableScaleInvHalf * frac * (tab_ewald_F_lj[ri] + f_lr))
-                        / six;
-
-                if (ii == jnr)
-                {
-                    /* If we get here, the i particle (ii) has itself (jnr)
-                     * in its neighborlist. This can only happen with the Verlet
-                     * scheme, and corresponds to a self-interaction that will
-                     * occur twice. Scale it down by 50% to only include it once.
-                     */
-                    VV *= half;
-                }
+                RealType v_lr, f_lr;
+                pmeLJCorrectionVF(
+                        rInv, rSq, ewaldLJCoeffSq, ewaldLJCoeffSixDivSix, &v_lr, &f_lr, computeVdwEwaldInteraction, bIiEqJnr);
+                v_lr = v_lr * oneSixth;
 
                 for (int i = 0; i < NSTATES; i++)
                 {
-                    const real c6grid = nbfp_grid[tj[i]];
-                    vVTot += LFV[i] * c6grid * VV;
-                    fScal += LFV[i] * c6grid * FF;
-                    dvdlVdw += (DLF[i] * c6grid) * VV;
+                    vVTot = vVTot + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
+                    fScal = fScal + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * f_lr, computeVdwEwaldInteraction);
+                    dvdlVdw = dvdlVdw + gmx::selectByMask(DLF[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
                 }
             }
 
-            if (doForces)
+            if (doForces && gmx::anyTrue(fScal != zero))
             {
-                const real tX = fScal * dX;
-                const real tY = fScal * dY;
-                const real tZ = fScal * dZ;
-                fIX           = fIX + tX;
-                fIY           = fIY + tY;
-                fIZ           = fIZ + tZ;
-                /* OpenMP atomics are expensive, but this kernels is also
-                 * expensive, so we can take this hit, instead of using
-                 * thread-local output buffers and extra reduction.
-                 *
-                 * All the OpenMP regions in this file are trivial and should
-                 * not throw, so no need for try/catch.
-                 */
-#pragma omp atomic
-                f[j3] -= tX;
-#pragma omp atomic
-                f[j3 + 1] -= tY;
-#pragma omp atomic
-                f[j3 + 2] -= tZ;
+                const RealType tX = fScal * dX;
+                const RealType tY = fScal * dY;
+                const RealType tZ = fScal * dZ;
+                fIX               = fIX + tX;
+                fIY               = fIY + tY;
+                fIZ               = fIZ + tZ;
+
+                gmx::transposeScatterDecrU<3>(
+                        reinterpret_cast<real*>(threadForceBuffer), preloadJnr, tX, tY, tZ);
             }
-        } // end for (int k = nj0; k < nj1; k++)
-
-        /* The atomics below are expensive with many OpenMP threads.
-         * Here unperturbed i-particles will usually only have a few
-         * (perturbed) j-particles in the list. Thus with a buffered list
-         * we can skip a significant number of i-reductions with a check.
-         */
-        if (npair_within_cutoff > 0)
+        } // end for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
+
+        if (havePairsWithinCutoff)
         {
             if (doForces)
             {
-#pragma omp atomic
-                f[ii3] += fIX;
-#pragma omp atomic
-                f[ii3 + 1] += fIY;
-#pragma omp atomic
-                f[ii3 + 2] += fIZ;
+                gmx::transposeScatterIncrU<3>(
+                        reinterpret_cast<real*>(threadForceBuffer), preloadIi, fIX, fIY, fIZ);
             }
             if (doShiftForces)
             {
-#pragma omp atomic
-                fshift[is3] += fIX;
-#pragma omp atomic
-                fshift[is3 + 1] += fIY;
-#pragma omp atomic
-                fshift[is3 + 2] += fIZ;
+                gmx::transposeScatterIncrU<3>(
+                        reinterpret_cast<real*>(threadForceShiftBuffer), preloadIs, fIX, fIY, fIZ);
             }
             if (doPotential)
             {
                 int ggid = gid[n];
-#pragma omp atomic
-                energygrp_elec[ggid] += vCTot;
-#pragma omp atomic
-                energygrp_vdw[ggid] += vVTot;
+                threadVc[ggid] += gmx::reduce(vCTot);
+                threadVv[ggid] += gmx::reduce(vVTot);
             }
         }
     } // end for (int n = 0; n < nri; n++)
 
-#pragma omp atomic
-    dvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)] += dvdlCoul;
-#pragma omp atomic
-    dvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)] += dvdlVdw;
+    if (gmx::anyTrue(dvdlCoul != zero))
+    {
+        threadDvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)] += gmx::reduce(dvdlCoul);
+    }
+    if (gmx::anyTrue(dvdlVdw != zero))
+    {
+        threadDvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)] += gmx::reduce(dvdlVdw);
+    }
 
     /* Estimate flops, average for free energy stuff:
      * 12  flops per outer iteration
      * 150 flops per inner iteration
+     * TODO: Update the number of flops and/or use different counts for different code paths.
      */
     atomicNrnbIncrement(nrnb, eNR_NBKERNEL_FREE_ENERGY, nlist.nri * 12 + nlist.jindex[nri] * 150);
 
-    if (numExcludedPairsBeyondRlist > 0)
+    if (haveExcludedPairsBeyondRlist > 0)
     {
         gmx_fatal(FARGS,
-                  "There are %d perturbed non-bonded pair interactions beyond the pair-list cutoff "
+                  "There are perturbed non-bonded pair interactions beyond the pair-list cutoff "
                   "of %g nm, which is not supported. This can happen because the system is "
                   "unstable or because intra-molecular interactions at long distances are "
                   "excluded. If the "
                   "latter is the case, you can try to increase nstlist or rlist to avoid this."
                   "The error is likely triggered by the use of couple-intramol=no "
                   "and the maximal distance in the decoupled molecule exceeding rlist.",
-                  numExcludedPairsBeyondRlist,
                   rlist);
     }
 }
 
-typedef void (*KernelFunction)(const t_nblist&                nlist,
-                               gmx::ArrayRef<const gmx::RVec> coords,
-                               gmx::ForceWithShiftForces*     forceWithShiftForces,
-                               const int                      ntype,
-                               const real                     rlist,
-                               const interaction_const_t&     ic,
-                               gmx::ArrayRef<const gmx::RVec> shiftvec,
-                               gmx::ArrayRef<const real>      nbfp,
-                               gmx::ArrayRef<const real>      nbfp_grid,
-                               gmx::ArrayRef<const real>      chargeA,
-                               gmx::ArrayRef<const real>      chargeB,
-                               gmx::ArrayRef<const int>       typeA,
-                               gmx::ArrayRef<const int>       typeB,
-                               int                            flags,
-                               gmx::ArrayRef<const real>      lambda,
-                               gmx::ArrayRef<real>            dvdl,
-                               gmx::ArrayRef<real>            energygrp_elec,
-                               gmx::ArrayRef<real>            energygrp_vdw,
-                               t_nrnb* gmx_restrict           nrnb);
+typedef void (*KernelFunction)(const t_nblist&                           nlist, // pointer type under which the dispatch functions return nb_free_energy_kernel<...> instantiations
+                               const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                               const int                                 ntype,
+                               const real                                rlist, // printed as the pair-list cutoff (nm) in the kernel's fatal-error message
+                               const interaction_const_t&                ic,
+                               gmx::ArrayRef<const gmx::RVec>            shiftvec,
+                               gmx::ArrayRef<const real>                 nbfp, // the kernel gathers c6/c12 pairs from this array
+                               gmx::ArrayRef<const real>                 nbfp_grid,
+                               gmx::ArrayRef<const real>                 chargeA,
+                               gmx::ArrayRef<const real>                 chargeB,
+                               gmx::ArrayRef<const int>                  typeA,
+                               gmx::ArrayRef<const int>                  typeB,
+                               int                                       flags,
+                               gmx::ArrayRef<const real>                 lambda,
+                               t_nrnb* gmx_restrict                      nrnb, // flop estimate is added through atomicNrnbIncrement
+                               gmx::RVec*                                threadForceBuffer, // when doForces: j-entry forces are scatter-decremented and the i-entry sum scatter-incremented here
+                               rvec*                                     threadForceShiftBuffer, // when doShiftForces: the i-entry force sum is scatter-incremented here
+                               gmx::ArrayRef<real>                       threadVc, // when doPotential: reduce(vCTot) is added at index gid[n]
+                               gmx::ArrayRef<real>                       threadVv, // when doPotential: reduce(vVTot) is added at index gid[n]
+                               gmx::ArrayRef<real>                       threadDvdl); // reduced dvdlCoul / dvdlVdw are added at the Coul / Vdw FreeEnergyPerturbationCouplingType indices
 
 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
 static KernelFunction dispatchKernelOnUseSimd(const bool useSimd)
@@ -889,8 +997,7 @@ static KernelFunction dispatchKernelOnUseSimd(const bool useSimd)
     if (useSimd)
     {
 #if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS && GMX_USE_SIMD_KERNELS
-        /* FIXME: Here SimdDataTypes should be used to enable SIMD. So far, the code in nb_free_energy_kernel is not adapted to SIMD */
-        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
+        return (nb_free_energy_kernel<SimdDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
 #else
         return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
 #endif
@@ -996,26 +1103,27 @@ static KernelFunction dispatchKernel(const bool                 scLambdasOrAlpha
 }
 
 
-void gmx_nb_free_energy_kernel(const t_nblist&                nlist,
-                               gmx::ArrayRef<const gmx::RVec> coords,
-                               gmx::ForceWithShiftForces*     ff,
-                               const bool                     useSimd,
-                               const int                      ntype,
-                               const real                     rlist,
-                               const interaction_const_t&     ic,
-                               gmx::ArrayRef<const gmx::RVec> shiftvec,
-                               gmx::ArrayRef<const real>      nbfp,
-                               gmx::ArrayRef<const real>      nbfp_grid,
-                               gmx::ArrayRef<const real>      chargeA,
-                               gmx::ArrayRef<const real>      chargeB,
-                               gmx::ArrayRef<const int>       typeA,
-                               gmx::ArrayRef<const int>       typeB,
-                               int                            flags,
-                               gmx::ArrayRef<const real>      lambda,
-                               gmx::ArrayRef<real>            dvdl,
-                               gmx::ArrayRef<real>            energygrp_elec,
-                               gmx::ArrayRef<real>            energygrp_vdw,
-                               t_nrnb*                        nrnb)
+void gmx_nb_free_energy_kernel(const t_nblist&                           nlist,
+                               const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                               const bool                                useSimd,
+                               const int                                 ntype,
+                               const real                                rlist,
+                               const interaction_const_t&                ic,
+                               gmx::ArrayRef<const gmx::RVec>            shiftvec,
+                               gmx::ArrayRef<const real>                 nbfp,
+                               gmx::ArrayRef<const real>                 nbfp_grid,
+                               gmx::ArrayRef<const real>                 chargeA,
+                               gmx::ArrayRef<const real>                 chargeB,
+                               gmx::ArrayRef<const int>                  typeA,
+                               gmx::ArrayRef<const int>                  typeB,
+                               int                                       flags,
+                               gmx::ArrayRef<const real>                 lambda,
+                               t_nrnb*                                   nrnb,
+                               gmx::RVec*                                threadForceBuffer,
+                               rvec*                                     threadForceShiftBuffer,
+                               gmx::ArrayRef<real>                       threadVc,
+                               gmx::ArrayRef<real>                       threadVv,
+                               gmx::ArrayRef<real>                       threadDvdl)
 {
     GMX_ASSERT(EEL_PME_EWALD(ic.eeltype) || ic.eeltype == CoulombInteractionType::Cut || EEL_RF(ic.eeltype),
                "Unsupported eeltype with free energy");
@@ -1050,7 +1158,6 @@ void gmx_nb_free_energy_kernel(const t_nblist&                nlist,
                                 ic);
     kernelFunc(nlist,
                coords,
-               ff,
                ntype,
                rlist,
                ic,
@@ -1063,8 +1170,10 @@ void gmx_nb_free_energy_kernel(const t_nblist&                nlist,
                typeB,
                flags,
                lambda,
-               dvdl,
-               energygrp_elec,
-               energygrp_vdw,
-               nrnb);
+               nrnb,
+               threadForceBuffer,
+               threadForceShiftBuffer,
+               threadVc,
+               threadVv,
+               threadDvdl);
 }