SIMD support for nonbonded free-energy kernels
authorMagnus Lundborg <magnus.lundborg@scilifelab.se>
Wed, 29 Sep 2021 09:16:19 +0000 (09:16 +0000)
committerBerk Hess <hess@kth.se>
Wed, 29 Sep 2021 09:16:19 +0000 (09:16 +0000)
15 files changed:
docs/release-notes/2022/major/highlights.rst
docs/release-notes/2022/major/performance.rst
src/gromacs/gmxlib/nonbonded/nb_free_energy.cpp
src/gromacs/gmxlib/nonbonded/nb_free_energy.h
src/gromacs/gmxlib/nonbonded/tests/nb_free_energy.cpp
src/gromacs/mdlib/sim_util.cpp
src/gromacs/nbnxm/CMakeLists.txt
src/gromacs/nbnxm/freeenergydispatch.cpp [new file with mode: 0644]
src/gromacs/nbnxm/freeenergydispatch.h [new file with mode: 0644]
src/gromacs/nbnxm/kerneldispatch.cpp
src/gromacs/nbnxm/nbnxm.cpp
src/gromacs/nbnxm/nbnxm.h
src/gromacs/nbnxm/nbnxm_setup.cpp
src/gromacs/timing/wallcycle.cpp
src/gromacs/timing/wallcycle.h

index 7be8d63f48deb8a9c267b4ac06276b0ab73a4415..f9adc743349a44c635bdbd0400432c18a21fa3e2 100644 (file)
@@ -12,6 +12,8 @@ several new features are available for running simulations. We are extremely
 interested in your feedback on how well the new release works on your
 simulations and hardware. The new features are:
 
+* Free-energy kernels are accelerated using SIMD, which make free-energy
+  calculations up to three times as fast when using GPUs
 * New transformation pull coordinate allows arbibrary mathematical transformations of one of more other pull coordinates
 * Cool quotes music play list
 
index 4e1f6baffc63a1f807c2494679ca82fadb898e99..580ef631a373f50839e0344d252424e58a974a0f 100644 (file)
@@ -13,3 +13,16 @@ Dynamic pairlist generation for energy minimization
 With energy minimization, the pairlist, and domain decomposition when running
 in parallel, is now performed when at least one atom has moved more than the
 half the pairlist buffer size. The pairlist used to be constructed every step.
+
+Nonbonded free-energy kernels use SIMD
+""""""""""""""""""""""""""""""""""""""
+
+Free energy calculation performance is improved by making the nonbonded free-energy
+kernels SIMD accelerated. On AVX2-256 these kernels are 4 to 8 times as fast.
+This should give a noticeable speed-up for most systems, especially if the
+perturbed interaction calculations were a bottleneck. This is particularly the
+case when using GPUs, where the performance improvement of free-energy runs is
+up to a factor of 3.
+
+:issue:`2875`
+:issue:`742`
index d73bc7fb0eca796d1fe8b57cbf7160d3f497b34e..45e09c25bc9169a97087f59b92e16f271236d0b1 100644 (file)
 #include "config.h"
 
 #include <cmath>
+#include <set>
 
 #include <algorithm>
 
 #include "gromacs/gmxlib/nrnb.h"
 #include "gromacs/gmxlib/nonbonded/nonbonded.h"
+#include "gromacs/math/arrayrefwithpadding.h"
 #include "gromacs/math/functions.h"
 #include "gromacs/math/vec.h"
 #include "gromacs/mdtypes/forceoutput.h"
@@ -55,7 +57,9 @@
 #include "gromacs/mdtypes/md_enums.h"
 #include "gromacs/mdtypes/mdatom.h"
 #include "gromacs/mdtypes/nblist.h"
+#include "gromacs/pbcutil/ishift.h"
 #include "gromacs/simd/simd.h"
+#include "gromacs/simd/simd_math.h"
 #include "gromacs/utility/fatalerror.h"
 #include "gromacs/utility/arrayref.h"
 
 //! Scalar (non-SIMD) data types.
 struct ScalarDataTypes
 {
-    using RealType                     = real; //!< The data type to use as real.
-    using IntType                      = int;  //!< The data type to use as int.
-    static constexpr int simdRealWidth = 1;    //!< The width of the RealType.
-    static constexpr int simdIntWidth  = 1;    //!< The width of the IntType.
+    using RealType = real; //!< The data type to use as real.
+    using IntType  = int;  //!< The data type to use as int.
+    using BoolType = bool; //!< The data type to use as bool for real value comparison.
+    static constexpr int simdRealWidth = 1; //!< The width of the RealType.
+    static constexpr int simdIntWidth  = 1; //!< The width of the IntType.
 };
 
 #if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS
 //! SIMD data types.
 struct SimdDataTypes
 {
-    using RealType                     = gmx::SimdReal;         //!< The data type to use as real.
-    using IntType                      = gmx::SimdInt32;        //!< The data type to use as int.
-    static constexpr int simdRealWidth = GMX_SIMD_REAL_WIDTH;   //!< The width of the RealType.
-    static constexpr int simdIntWidth  = GMX_SIMD_FINT32_WIDTH; //!< The width of the IntType.
+    using RealType = gmx::SimdReal;  //!< The data type to use as real.
+    using IntType  = gmx::SimdInt32; //!< The data type to use as int.
+    using BoolType = gmx::SimdBool;  //!< The data type to use as bool for real value comparison.
+    static constexpr int simdRealWidth = GMX_SIMD_REAL_WIDTH; //!< The width of the RealType.
+#    if GMX_SIMD_HAVE_DOUBLE && GMX_DOUBLE
+    static constexpr int simdIntWidth = GMX_SIMD_DINT32_WIDTH; //!< The width of the IntType.
+#    else
+    static constexpr int simdIntWidth = GMX_SIMD_FINT32_WIDTH; //!< The width of the IntType.
+#    endif
 };
 #endif
 
+template<class RealType, class BoolType>
+static inline void
+pmeCoulombCorrectionVF(const RealType rSq, const real beta, RealType* pot, RealType* force, const BoolType mask)
+{
+    const RealType brsq = gmx::selectByMask(rSq * beta * beta, mask);
+    *force              = -brsq * beta * gmx::pmeForceCorrection(brsq);
+    *pot                = beta * gmx::pmePotentialCorrection(brsq);
+}
+
+template<class RealType, class BoolType>
+static inline void pmeLJCorrectionVF(const RealType rInv,
+                                     const RealType rSq,
+                                     const real     ewaldLJCoeffSq,
+                                     const real     ewaldLJCoeffSixDivSix,
+                                     RealType*      pot,
+                                     RealType*      force,
+                                     const BoolType mask,
+                                     const BoolType bIiEqJnr)
+{
+    // We mask rInv to get zero force and potential for masked out pair interactions
+    const RealType rInvSq  = gmx::selectByMask(rInv * rInv, mask);
+    const RealType rInvSix = rInvSq * rInvSq * rInvSq;
+    // Mask rSq to avoid underflow in exp()
+    const RealType coeffSqRSq       = ewaldLJCoeffSq * gmx::selectByMask(rSq, mask);
+    const RealType expNegCoeffSqRSq = gmx::exp(-coeffSqRSq);
+    const RealType poly             = 1.0_real + coeffSqRSq + 0.5_real * coeffSqRSq * coeffSqRSq;
+    *force = rInvSix - expNegCoeffSqRSq * (rInvSix * poly + ewaldLJCoeffSixDivSix);
+    *force = *force * rInvSq;
+    // The self interaction is the limit for r -> 0 which we need to compute separately
+    *pot = gmx::blend(
+            rInvSix * (1.0_real - expNegCoeffSqRSq * poly), 0.5_real * ewaldLJCoeffSixDivSix, bIiEqJnr);
+}
+
 //! Computes r^(1/p) and 1/r^(1/p) for the standard p=6
-template<class RealType>
-static inline void pthRoot(const RealType r, RealType* pthRoot, RealType* invPthRoot)
+template<class RealType, class BoolType>
+static inline void pthRoot(const RealType r, RealType* pthRoot, RealType* invPthRoot, const BoolType mask)
 {
-    *invPthRoot = gmx::invsqrt(std::cbrt(r));
-    *pthRoot    = 1 / (*invPthRoot);
+    RealType cbrtRes = gmx::cbrt(r);
+    *invPthRoot      = gmx::maskzInvsqrt(cbrtRes, mask);
+    *pthRoot         = gmx::maskzInv(*invPthRoot, mask);
 }
 
 template<class RealType>
@@ -159,65 +203,56 @@ static inline RealType lennardJonesPotential(const RealType v6,
 }
 
 /* Ewald LJ */
-static inline real ewaldLennardJonesGridSubtract(const real c6grid, const real potentialShift, const real oneSixth)
+template<class RealType>
+static inline RealType ewaldLennardJonesGridSubtract(const RealType c6grid,
+                                                     const real     potentialShift,
+                                                     const real     oneSixth)
 {
     return (c6grid * potentialShift * oneSixth);
 }
 
 /* LJ Potential switch */
-template<class RealType>
+template<class RealType, class BoolType>
 static inline RealType potSwitchScalarForceMod(const RealType fScalarInp,
                                                const RealType potential,
                                                const RealType sw,
                                                const RealType r,
-                                               const RealType rVdw,
                                                const RealType dsw,
-                                               const real     zero)
+                                               const BoolType mask)
 {
-    if (r < rVdw)
-    {
-        real fScalar = fScalarInp * sw - r * potential * dsw;
-        return (fScalar);
-    }
-    return (zero);
+    /* The mask should select on rV < rVdw */
+    return (gmx::selectByMask(fScalarInp * sw - r * potential * dsw, mask));
 }
-template<class RealType>
-static inline RealType potSwitchPotentialMod(const RealType potentialInp,
-                                             const RealType sw,
-                                             const RealType r,
-                                             const RealType rVdw,
-                                             const real     zero)
+template<class RealType, class BoolType>
+static inline RealType potSwitchPotentialMod(const RealType potentialInp, const RealType sw, const BoolType mask)
 {
-    if (r < rVdw)
-    {
-        real potential = potentialInp * sw;
-        return (potential);
-    }
-    return (zero);
+    /* The mask should select on rV < rVdw */
+    return (gmx::selectByMask(potentialInp * sw, mask));
 }
 
 
 //! Templated free-energy non-bonded kernel
 template<typename DataTypes, bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
-static void nb_free_energy_kernel(const t_nblist&                nlist,
-                                  gmx::ArrayRef<const gmx::RVec> coords,
-                                  gmx::ForceWithShiftForces*     forceWithShiftForces,
-                                  const int                      ntype,
-                                  const real                     rlist,
-                                  const interaction_const_t&     ic,
-                                  gmx::ArrayRef<const gmx::RVec> shiftvec,
-                                  gmx::ArrayRef<const real>      nbfp,
-                                  gmx::ArrayRef<const real>      nbfp_grid,
-                                  gmx::ArrayRef<const real>      chargeA,
-                                  gmx::ArrayRef<const real>      chargeB,
-                                  gmx::ArrayRef<const int>       typeA,
-                                  gmx::ArrayRef<const int>       typeB,
-                                  int                            flags,
-                                  gmx::ArrayRef<const real>      lambda,
-                                  gmx::ArrayRef<real>            dvdl,
-                                  gmx::ArrayRef<real>            energygrp_elec,
-                                  gmx::ArrayRef<real>            energygrp_vdw,
-                                  t_nrnb* gmx_restrict           nrnb)
+static void nb_free_energy_kernel(const t_nblist&                           nlist,
+                                  const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                                  const int                                 ntype,
+                                  const real                                rlist,
+                                  const interaction_const_t&                ic,
+                                  gmx::ArrayRef<const gmx::RVec>            shiftvec,
+                                  gmx::ArrayRef<const real>                 nbfp,
+                                  gmx::ArrayRef<const real> gmx_unused      nbfp_grid,
+                                  gmx::ArrayRef<const real>                 chargeA,
+                                  gmx::ArrayRef<const real>                 chargeB,
+                                  gmx::ArrayRef<const int>                  typeA,
+                                  gmx::ArrayRef<const int>                  typeB,
+                                  int                                       flags,
+                                  gmx::ArrayRef<const real>                 lambda,
+                                  t_nrnb* gmx_restrict                      nrnb,
+                                  gmx::RVec*                                threadForceBuffer,
+                                  rvec*                                     threadForceShiftBuffer,
+                                  gmx::ArrayRef<real>                       threadVc,
+                                  gmx::ArrayRef<real>                       threadVv,
+                                  gmx::ArrayRef<real>                       threadDvdl)
 {
 #define STATE_A 0
 #define STATE_B 1
@@ -225,15 +260,15 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
 
     using RealType = typename DataTypes::RealType;
     using IntType  = typename DataTypes::IntType;
+    using BoolType = typename DataTypes::BoolType;
 
-    /* FIXME: How should these be handled with SIMD? */
-    constexpr real oneTwelfth = 1.0 / 12.0;
-    constexpr real oneSixth   = 1.0 / 6.0;
-    constexpr real zero       = 0.0;
-    constexpr real half       = 0.5;
-    constexpr real one        = 1.0;
-    constexpr real two        = 2.0;
-    constexpr real six        = 6.0;
+    constexpr real oneTwelfth = 1.0_real / 12.0_real;
+    constexpr real oneSixth   = 1.0_real / 6.0_real;
+    constexpr real zero       = 0.0_real;
+    constexpr real half       = 0.5_real;
+    constexpr real one        = 1.0_real;
+    constexpr real two        = 2.0_real;
+    constexpr real six        = 6.0_real;
 
     // Extract pair list data
     const int                nri    = nlist.nri;
@@ -243,47 +278,56 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
     gmx::ArrayRef<const int> shift  = nlist.shift;
     gmx::ArrayRef<const int> gid    = nlist.gid;
 
-    const real  lambda_coul   = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)];
-    const real  lambda_vdw    = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)];
-    const auto& scParams      = *ic.softCoreParameters;
-    const real  alpha_coul    = scParams.alphaCoulomb;
-    const real  alpha_vdw     = scParams.alphaVdw;
-    const real  lam_power     = scParams.lambdaPower;
-    const real  sigma6_def    = scParams.sigma6WithInvalidSigma;
-    const real  sigma6_min    = scParams.sigma6Minimum;
-    const bool  doForces      = ((flags & GMX_NONBONDED_DO_FORCE) != 0);
-    const bool  doShiftForces = ((flags & GMX_NONBONDED_DO_SHIFTFORCE) != 0);
-    const bool  doPotential   = ((flags & GMX_NONBONDED_DO_POTENTIAL) != 0);
+    const real  lambda_coul = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)];
+    const real  lambda_vdw  = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)];
+    const auto& scParams    = *ic.softCoreParameters;
+    const real gmx_unused alpha_coul    = scParams.alphaCoulomb;
+    const real gmx_unused alpha_vdw     = scParams.alphaVdw;
+    const real            lam_power     = scParams.lambdaPower;
+    const real gmx_unused sigma6_def    = scParams.sigma6WithInvalidSigma;
+    const real gmx_unused sigma6_min    = scParams.sigma6Minimum;
+    const bool            doForces      = ((flags & GMX_NONBONDED_DO_FORCE) != 0);
+    const bool            doShiftForces = ((flags & GMX_NONBONDED_DO_SHIFTFORCE) != 0);
+    const bool            doPotential   = ((flags & GMX_NONBONDED_DO_POTENTIAL) != 0);
 
     // Extract data from interaction_const_t
-    const real facel           = ic.epsfac;
-    const real rCoulomb        = ic.rcoulomb;
-    const real krf             = ic.reactionFieldCoefficient;
-    const real crf             = ic.reactionFieldShift;
-    const real shLjEwald       = ic.sh_lj_ewald;
-    const real rVdw            = ic.rvdw;
-    const real dispersionShift = ic.dispersion_shift.cpot;
-    const real repulsionShift  = ic.repulsion_shift.cpot;
+    const real            facel           = ic.epsfac;
+    const real            rCoulomb        = ic.rcoulomb;
+    const real            krf             = ic.reactionFieldCoefficient;
+    const real            crf             = ic.reactionFieldShift;
+    const real gmx_unused shLjEwald       = ic.sh_lj_ewald;
+    const real            rVdw            = ic.rvdw;
+    const real            dispersionShift = ic.dispersion_shift.cpot;
+    const real            repulsionShift  = ic.repulsion_shift.cpot;
+    const real            ewaldBeta       = ic.ewaldcoeff_q;
+    real gmx_unused       ewaldLJCoeffSq;
+    real gmx_unused       ewaldLJCoeffSixDivSix;
+    if constexpr (vdwInteractionTypeIsEwald)
+    {
+        ewaldLJCoeffSq        = ic.ewaldcoeff_lj * ic.ewaldcoeff_lj;
+        ewaldLJCoeffSixDivSix = ewaldLJCoeffSq * ewaldLJCoeffSq * ewaldLJCoeffSq / six;
+    }
 
     // Note that the nbnxm kernels do not support Coulomb potential switching at all
     GMX_ASSERT(ic.coulomb_modifier != InteractionModifiers::PotSwitch,
                "Potential switching is not supported for Coulomb with FEP");
 
-    real vdw_swV3, vdw_swV4, vdw_swV5, vdw_swF2, vdw_swF3, vdw_swF4;
-    if (vdwModifierIsPotSwitch)
+    const real      rVdwSwitch = ic.rvdw_switch;
+    real gmx_unused vdw_swV3, vdw_swV4, vdw_swV5, vdw_swF2, vdw_swF3, vdw_swF4;
+    if constexpr (vdwModifierIsPotSwitch)
     {
-        const real d = ic.rvdw - ic.rvdw_switch;
-        vdw_swV3     = -10.0 / (d * d * d);
-        vdw_swV4     = 15.0 / (d * d * d * d);
-        vdw_swV5     = -6.0 / (d * d * d * d * d);
-        vdw_swF2     = -30.0 / (d * d * d);
-        vdw_swF3     = 60.0 / (d * d * d * d);
-        vdw_swF4     = -30.0 / (d * d * d * d * d);
+        const real d = rVdw - rVdwSwitch;
+        vdw_swV3     = -10.0_real / (d * d * d);
+        vdw_swV4     = 15.0_real / (d * d * d * d);
+        vdw_swV5     = -6.0_real / (d * d * d * d * d);
+        vdw_swF2     = -30.0_real / (d * d * d);
+        vdw_swF3     = 60.0_real / (d * d * d * d);
+        vdw_swF4     = -30.0_real / (d * d * d * d * d);
     }
     else
     {
         /* Avoid warnings from stupid compilers (looking at you, Clang!) */
-        vdw_swV3 = vdw_swV4 = vdw_swV5 = vdw_swF2 = vdw_swF3 = vdw_swF4 = 0.0;
+        vdw_swV3 = vdw_swV4 = vdw_swV5 = vdw_swF2 = vdw_swF3 = vdw_swF4 = zero;
     }
 
     NbkernelElecType icoul;
@@ -299,33 +343,11 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
     real rcutoff_max2 = std::max(ic.rcoulomb, ic.rvdw);
     rcutoff_max2      = rcutoff_max2 * rcutoff_max2;
 
-    const real* tab_ewald_F_lj           = nullptr;
-    const real* tab_ewald_V_lj           = nullptr;
-    const real* ewtab                    = nullptr;
-    real        coulombTableScale        = 0;
-    real        coulombTableScaleInvHalf = 0;
-    real        vdwTableScale            = 0;
-    real        vdwTableScaleInvHalf     = 0;
-    real        sh_ewald                 = 0;
-    if (elecInteractionTypeIsEwald || vdwInteractionTypeIsEwald)
+    real gmx_unused sh_ewald = 0;
+    if constexpr (elecInteractionTypeIsEwald || vdwInteractionTypeIsEwald)
     {
         sh_ewald = ic.sh_ewald;
     }
-    if (elecInteractionTypeIsEwald)
-    {
-        const auto& coulombTables = *ic.coulombEwaldTables;
-        ewtab                     = coulombTables.tableFDV0.data();
-        coulombTableScale         = coulombTables.scale;
-        coulombTableScaleInvHalf  = half / coulombTableScale;
-    }
-    if (vdwInteractionTypeIsEwald)
-    {
-        const auto& vdwTables = *ic.vdwEwaldTables;
-        tab_ewald_F_lj        = vdwTables.tableF.data();
-        tab_ewald_V_lj        = vdwTables.tableV.data();
-        vdwTableScale         = vdwTables.scale;
-        vdwTableScaleInvHalf  = half / vdwTableScale;
-    }
 
     /* For Ewald/PME interactions we cannot easily apply the soft-core component to
      * reciprocal space. When we use non-switched Ewald interactions, we
@@ -342,8 +364,8 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
     GMX_RELEASE_ASSERT(!(vdwInteractionTypeIsEwald && vdwModifierIsPotSwitch),
                        "Can not apply soft-core to switched Ewald potentials");
 
-    real dvdlCoul = 0;
-    real dvdlVdw  = 0;
+    RealType dvdlCoul(zero);
+    RealType dvdlVdw(zero);
 
     /* Lambda factor for state A, 1-lambda*/
     real LFC[NSTATES], LFV[NSTATES];
@@ -356,11 +378,11 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
 
     /*derivative of the lambda factor for state A and B */
     real DLF[NSTATES];
-    DLF[STATE_A] = -1;
-    DLF[STATE_B] = 1;
+    DLF[STATE_A] = -one;
+    DLF[STATE_B] = one;
 
-    real           lFacCoul[NSTATES], dlFacCoul[NSTATES], lFacVdw[NSTATES], dlFacVdw[NSTATES];
-    constexpr real sc_r_power = 6.0_real;
+    real gmx_unused lFacCoul[NSTATES], dlFacCoul[NSTATES], lFacVdw[NSTATES], dlFacVdw[NSTATES];
+    constexpr real  sc_r_power = six;
     for (int i = 0; i < NSTATES; i++)
     {
         lFacCoul[i]  = (lam_power == 2 ? (1 - LFC[i]) * (1 - LFC[i]) : (1 - LFC[i]));
@@ -370,57 +392,172 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
     }
 
     // TODO: We should get rid of using pointers to real
-    const real*        x      = coords[0];
-    real* gmx_restrict f      = &(forceWithShiftForces->force()[0][0]);
-    real* gmx_restrict fshift = &(forceWithShiftForces->shiftForces()[0][0]);
+    const real* gmx_restrict x = coords.paddedConstArrayRef().data()[0];
 
     const real rlistSquared = gmx::square(rlist);
 
-    int numExcludedPairsBeyondRlist = 0;
+    bool haveExcludedPairsBeyondRlist = false;
 
     for (int n = 0; n < nri; n++)
     {
-        int npair_within_cutoff = 0;
-
-        const int  is    = shift[n];
-        const int  is3   = DIM * is;
-        const real shX   = shiftvec[is][XX];
-        const real shY   = shiftvec[is][YY];
-        const real shZ   = shiftvec[is][ZZ];
-        const int  nj0   = jindex[n];
-        const int  nj1   = jindex[n + 1];
-        const int  ii    = iinr[n];
-        const int  ii3   = 3 * ii;
-        const real ix    = shX + x[ii3 + 0];
-        const real iy    = shY + x[ii3 + 1];
-        const real iz    = shZ + x[ii3 + 2];
-        const real iqA   = facel * chargeA[ii];
-        const real iqB   = facel * chargeB[ii];
-        const int  ntiA  = 2 * ntype * typeA[ii];
-        const int  ntiB  = 2 * ntype * typeB[ii];
-        real       vCTot = 0;
-        real       vVTot = 0;
-        real       fIX   = 0;
-        real       fIY   = 0;
-        real       fIZ   = 0;
-
-        for (int k = nj0; k < nj1; k++)
+        bool havePairsWithinCutoff = false;
+
+        const int  is   = shift[n];
+        const real shX  = shiftvec[is][XX];
+        const real shY  = shiftvec[is][YY];
+        const real shZ  = shiftvec[is][ZZ];
+        const int  nj0  = jindex[n];
+        const int  nj1  = jindex[n + 1];
+        const int  ii   = iinr[n];
+        const int  ii3  = 3 * ii;
+        const real ix   = shX + x[ii3 + 0];
+        const real iy   = shY + x[ii3 + 1];
+        const real iz   = shZ + x[ii3 + 2];
+        const real iqA  = facel * chargeA[ii];
+        const real iqB  = facel * chargeB[ii];
+        const int  ntiA = ntype * typeA[ii];
+        const int  ntiB = ntype * typeB[ii];
+        RealType   vCTot(0);
+        RealType   vVTot(0);
+        RealType   fIX(0);
+        RealType   fIY(0);
+        RealType   fIZ(0);
+
+#if GMX_SIMD_HAVE_REAL
+        alignas(GMX_SIMD_ALIGNMENT) int preloadIi[DataTypes::simdRealWidth];
+        alignas(GMX_SIMD_ALIGNMENT) int preloadIs[DataTypes::simdRealWidth];
+#else
+        int preloadIi[DataTypes::simdRealWidth];
+        int preloadIs[DataTypes::simdRealWidth];
+#endif
+        for (int s = 0; s < DataTypes::simdRealWidth; s++)
         {
-            int            tj[NSTATES];
-            const int      jnr = jjnr[k];
-            const int      j3  = 3 * jnr;
-            RealType       c6[NSTATES], c12[NSTATES], qq[NSTATES], vCoul[NSTATES], vVdw[NSTATES];
-            RealType       r, rInv, rp, rpm2;
-            RealType       alphaVdwEff, alphaCoulEff, sigma6[NSTATES];
-            const RealType dX  = ix - x[j3];
-            const RealType dY  = iy - x[j3 + 1];
-            const RealType dZ  = iz - x[j3 + 2];
+            preloadIi[s] = ii;
+            preloadIs[s] = shift[n];
+        }
+        IntType ii_s = gmx::load<IntType>(preloadIi);
+
+        for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
+        {
+            RealType r, rInv;
+
+#if GMX_SIMD_HAVE_REAL
+            alignas(GMX_SIMD_ALIGNMENT) real    preloadPairIsValid[DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real    preloadPairIncluded[DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) int32_t preloadJnr[DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) int32_t typeIndices[NSTATES][DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real    preloadQq[NSTATES][DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadSigma6[NSTATES][DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadAlphaVdwEff[DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadAlphaCoulEff[DataTypes::simdRealWidth];
+            alignas(GMX_SIMD_ALIGNMENT) real preloadLjPmeC6Grid[NSTATES][DataTypes::simdRealWidth];
+#else
+            real            preloadPairIsValid[DataTypes::simdRealWidth];
+            real            preloadPairIncluded[DataTypes::simdRealWidth];
+            int             preloadJnr[DataTypes::simdRealWidth];
+            int             typeIndices[NSTATES][DataTypes::simdRealWidth];
+            real            preloadQq[NSTATES][DataTypes::simdRealWidth];
+            real gmx_unused preloadSigma6[NSTATES][DataTypes::simdRealWidth];
+            real gmx_unused preloadAlphaVdwEff[DataTypes::simdRealWidth];
+            real gmx_unused preloadAlphaCoulEff[DataTypes::simdRealWidth];
+            real            preloadLjPmeC6Grid[NSTATES][DataTypes::simdRealWidth];
+#endif
+            for (int s = 0; s < DataTypes::simdRealWidth; s++)
+            {
+                if (k + s < nj1)
+                {
+                    preloadPairIsValid[s] = true;
+                    /* Check if this pair on the exclusions list.*/
+                    preloadPairIncluded[s]  = (nlist.excl_fep.empty() || nlist.excl_fep[k + s]);
+                    const int jnr           = jjnr[k + s];
+                    preloadJnr[s]           = jnr;
+                    typeIndices[STATE_A][s] = ntiA + typeA[jnr];
+                    typeIndices[STATE_B][s] = ntiB + typeB[jnr];
+                    preloadQq[STATE_A][s]   = iqA * chargeA[jnr];
+                    preloadQq[STATE_B][s]   = iqB * chargeB[jnr];
+
+                    for (int i = 0; i < NSTATES; i++)
+                    {
+                        if constexpr (vdwInteractionTypeIsEwald)
+                        {
+                            preloadLjPmeC6Grid[i][s] = nbfp_grid[2 * typeIndices[i][s]];
+                        }
+                        else
+                        {
+                            preloadLjPmeC6Grid[i][s] = 0;
+                        }
+                        if constexpr (useSoftCore)
+                        {
+                            const real c6  = nbfp[2 * typeIndices[i][s]];
+                            const real c12 = nbfp[2 * typeIndices[i][s] + 1];
+                            if (c6 > 0 && c12 > 0)
+                            {
+                                /* c12 is stored scaled with 12.0 and c6 is scaled with 6.0 - correct for this */
+                                preloadSigma6[i][s] = 0.5_real * c12 / c6;
+                                if (preloadSigma6[i][s]
+                                    < sigma6_min) /* for disappearing coul and vdw with soft core at the same time */
+                                {
+                                    preloadSigma6[i][s] = sigma6_min;
+                                }
+                            }
+                            else
+                            {
+                                preloadSigma6[i][s] = sigma6_def;
+                            }
+                        }
+                    }
+                    if constexpr (useSoftCore)
+                    {
+                        /* only use softcore if one of the states has a zero endstate - softcore is for avoiding infinities!*/
+                        const real c12A = nbfp[2 * typeIndices[STATE_A][s] + 1];
+                        const real c12B = nbfp[2 * typeIndices[STATE_B][s] + 1];
+                        if (c12A > 0 && c12B > 0)
+                        {
+                            preloadAlphaVdwEff[s]  = 0;
+                            preloadAlphaCoulEff[s] = 0;
+                        }
+                        else
+                        {
+                            preloadAlphaVdwEff[s]  = alpha_vdw;
+                            preloadAlphaCoulEff[s] = alpha_coul;
+                        }
+                    }
+                }
+                else
+                {
+                    preloadJnr[s]          = jjnr[k];
+                    preloadPairIsValid[s]  = false;
+                    preloadPairIncluded[s] = false;
+                    preloadAlphaVdwEff[s]  = 0;
+                    preloadAlphaCoulEff[s] = 0;
+
+                    for (int i = 0; i < NSTATES; i++)
+                    {
+                        typeIndices[STATE_A][s]  = ntiA + typeA[jjnr[k]];
+                        typeIndices[STATE_B][s]  = ntiB + typeB[jjnr[k]];
+                        preloadLjPmeC6Grid[i][s] = 0;
+                        preloadQq[i][s]          = 0;
+                        preloadSigma6[i][s]      = 0;
+                    }
+                }
+            }
+
+            RealType jx, jy, jz;
+            gmx::gatherLoadUTranspose<3>(reinterpret_cast<const real*>(x), preloadJnr, &jx, &jy, &jz);
+
+            const RealType pairIsValid   = gmx::load<RealType>(preloadPairIsValid);
+            const RealType pairIncluded  = gmx::load<RealType>(preloadPairIncluded);
+            const BoolType bPairIncluded = (pairIncluded != zero);
+            const BoolType bPairExcluded = (pairIncluded == zero && pairIsValid != zero);
+
+            const RealType dX  = ix - jx;
+            const RealType dY  = iy - jy;
+            const RealType dZ  = iz - jz;
             const RealType rSq = dX * dX + dY * dY + dZ * dZ;
-            RealType       fScalC[NSTATES], fScalV[NSTATES];
-            /* Check if this pair on the exlusions list.*/
-            const bool bPairIncluded = nlist.excl_fep.empty() || nlist.excl_fep[k];
 
-            if (rSq >= rcutoff_max2 && bPairIncluded)
+            BoolType withinCutoffMask = (rSq < rcutoff_max2);
+
+            if (!gmx::anyTrue(withinCutoffMask || bPairExcluded))
             {
                 /* We save significant time by skipping all code below.
                  * Note that with soft-core interactions, the actual cut-off
@@ -430,38 +567,55 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                  * when using Ewald: the reciprocal-space
                  * Ewald component still needs to be subtracted.
                  */
-
                 continue;
             }
-            npair_within_cutoff++;
-
-            if (rSq > rlistSquared)
+            else
             {
-                numExcludedPairsBeyondRlist++;
+                havePairsWithinCutoff = true;
             }
 
-            if (rSq > 0)
+            if (gmx::anyTrue(rlistSquared < rSq && bPairExcluded))
             {
-                /* Note that unlike in the nbnxn kernels, we do not need
-                 * to clamp the value of rSq before taking the invsqrt
-                 * to avoid NaN in the LJ calculation, since here we do
-                 * not calculate LJ interactions when C6 and C12 are zero.
-                 */
+                haveExcludedPairsBeyondRlist = true;
+            }
 
-                rInv = gmx::invsqrt(rSq);
-                r    = rSq * rInv;
+            const IntType  jnr_s    = gmx::load<IntType>(preloadJnr);
+            const BoolType bIiEqJnr = gmx::cvtIB2B(ii_s == jnr_s);
+
+            RealType            c6[NSTATES];
+            RealType            c12[NSTATES];
+            RealType gmx_unused sigma6[NSTATES];
+            RealType            qq[NSTATES];
+            RealType gmx_unused ljPmeC6Grid[NSTATES];
+            RealType gmx_unused alphaVdwEff;
+            RealType gmx_unused alphaCoulEff;
+            for (int i = 0; i < NSTATES; i++)
+            {
+                gmx::gatherLoadTranspose<2>(nbfp.data(), typeIndices[i], &c6[i], &c12[i]);
+                qq[i]          = gmx::load<RealType>(preloadQq[i]);
+                ljPmeC6Grid[i] = gmx::load<RealType>(preloadLjPmeC6Grid[i]);
+                if constexpr (useSoftCore)
+                {
+                    sigma6[i] = gmx::load<RealType>(preloadSigma6[i]);
+                }
             }
-            else
+            if constexpr (useSoftCore)
             {
-                /* The force at r=0 is zero, because of symmetry.
-                 * But note that the potential is in general non-zero,
-                 * since the soft-cored r will be non-zero.
-                 */
-                rInv = 0;
-                r    = 0;
+                alphaVdwEff  = gmx::load<RealType>(preloadAlphaVdwEff);
+                alphaCoulEff = gmx::load<RealType>(preloadAlphaCoulEff);
             }
 
-            if (useSoftCore)
+            BoolType rSqValid = (zero < rSq);
+
+            /* The force at r=0 is zero, because of symmetry.
+             * But note that the potential is in general non-zero,
+             * since the soft-cored r will be non-zero.
+             */
+            rInv = gmx::maskzInvsqrt(rSq, rSqValid);
+            r    = rSq * rInv;
+
+            RealType gmx_unused rp, rpm2;
+            if constexpr (useSoftCore)
             {
                 rpm2 = rSq * rSq;  /* r4 */
                 rp   = rpm2 * rSq; /* r6 */
@@ -473,79 +627,45 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                  * the simplest math and cheapest code.
                  */
                 rpm2 = rInv * rInv;
-                rp   = 1;
+                rp   = one;
             }
 
-            RealType fScal = 0;
+            RealType fScal(0);
 
-            qq[STATE_A] = iqA * chargeA[jnr];
-            qq[STATE_B] = iqB * chargeB[jnr];
-
-            tj[STATE_A] = ntiA + 2 * typeA[jnr];
-            tj[STATE_B] = ntiB + 2 * typeB[jnr];
-
-            if (bPairIncluded)
+            /* The following block is masked to only calculate values having bPairIncluded. If
+             * bPairIncluded is true then withinCutoffMask must also be true. */
+            if (gmx::anyTrue(withinCutoffMask && bPairIncluded))
             {
-                c6[STATE_A] = nbfp[tj[STATE_A]];
-                c6[STATE_B] = nbfp[tj[STATE_B]];
-
+                RealType fScalC[NSTATES], fScalV[NSTATES];
+                RealType vCoul[NSTATES], vVdw[NSTATES];
                 for (int i = 0; i < NSTATES; i++)
                 {
-                    c12[i] = nbfp[tj[i] + 1];
-                    if (useSoftCore)
+                    fScalC[i] = zero;
+                    fScalV[i] = zero;
+                    vCoul[i]  = zero;
+                    vVdw[i]   = zero;
+
+                    RealType gmx_unused rInvC, rInvV, rC, rV, rPInvC, rPInvV;
+
+                    /* The following block is masked to require (qq[i] != 0 || c6[i] != 0 || c12[i]
+                     * != 0) in addition to bPairIncluded, which in turn requires withinCutoffMask. */
+                    BoolType nonZeroState = ((qq[i] != zero || c6[i] != zero || c12[i] != zero)
+                                             && bPairIncluded && withinCutoffMask);
+                    if (gmx::anyTrue(nonZeroState))
                     {
-                        if ((c6[i] > 0) && (c12[i] > 0))
-                        {
-                            /* c12 is stored scaled with 12.0 and c6 is scaled with 6.0 - correct for this */
-                            sigma6[i] = half * c12[i] / c6[i];
-                            if (sigma6[i] < sigma6_min) /* for disappearing coul and vdw with soft core at the same time */
-                            {
-                                sigma6[i] = sigma6_min;
-                            }
-                        }
-                        else
+                        if constexpr (useSoftCore)
                         {
-                            sigma6[i] = sigma6_def;
-                        }
-                    }
-                }
-
-                if (useSoftCore)
-                {
-                    /* only use softcore if one of the states has a zero endstate - softcore is for avoiding infinities!*/
-                    if ((c12[STATE_A] > 0) && (c12[STATE_B] > 0))
-                    {
-                        alphaVdwEff  = 0;
-                        alphaCoulEff = 0;
-                    }
-                    else
-                    {
-                        alphaVdwEff  = alpha_vdw;
-                        alphaCoulEff = alpha_coul;
-                    }
-                }
-
-                for (int i = 0; i < NSTATES; i++)
-                {
-                    fScalC[i] = 0;
-                    fScalV[i] = 0;
-                    vCoul[i]  = 0;
-                    vVdw[i]   = 0;
-
-                    RealType rInvC, rInvV, rC, rV, rPInvC, rPInvV;
+                            RealType divisor      = (alphaCoulEff * lFacCoul[i] * sigma6[i] + rp);
+                            BoolType validDivisor = (zero < divisor);
+                            rPInvC                = gmx::maskzInv(divisor, validDivisor);
+                            pthRoot(rPInvC, &rInvC, &rC, validDivisor);
 
-                    /* Only spend time on A or B state if it is non-zero */
-                    if ((qq[i] != 0) || (c6[i] != 0) || (c12[i] != 0))
-                    {
-                        /* this section has to be inside the loop because of the dependence on sigma6 */
-                        if (useSoftCore)
-                        {
-                            rPInvC = one / (alphaCoulEff * lFacCoul[i] * sigma6[i] + rp);
-                            pthRoot(rPInvC, &rInvC, &rC);
-                            if (scLambdasOrAlphasDiffer)
+                            if constexpr (scLambdasOrAlphasDiffer)
                             {
-                                rPInvV = one / (alphaVdwEff * lFacVdw[i] * sigma6[i] + rp);
-                                pthRoot(rPInvV, &rInvV, &rV);
+                                RealType divisor      = (alphaVdwEff * lFacVdw[i] * sigma6[i] + rp);
+                                BoolType validDivisor = (zero < divisor);
+                                rPInvV                = gmx::maskzInv(divisor, validDivisor);
+                                pthRoot(rPInvV, &rInvV, &rV, validDivisor);
                             }
                             else
                             {
@@ -557,25 +677,31 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                         }
                         else
                         {
-                            rPInvC = 1;
+                            rPInvC = one;
                             rInvC  = rInv;
                             rC     = r;
 
-                            rPInvV = 1;
+                            rPInvV = one;
                             rInvV  = rInv;
                             rV     = r;
                         }
 
-                        /* Only process the coulomb interactions if we have charges,
-                         * and if we either include all entries in the list (no cutoff
+                        /* Only process the coulomb interactions if we either
+                         * include all entries in the list (no cutoff
                          * used in the kernel), or if we are within the cutoff.
                          */
-                        bool computeElecInteraction = (elecInteractionTypeIsEwald && r < rCoulomb)
-                                                      || (!elecInteractionTypeIsEwald && rC < rCoulomb);
-
-                        if ((qq[i] != 0) && computeElecInteraction)
+                        BoolType computeElecInteraction;
+                        if constexpr (elecInteractionTypeIsEwald)
                         {
-                            if (elecInteractionTypeIsEwald)
+                            computeElecInteraction = (r < rCoulomb && qq[i] != zero && bPairIncluded);
+                        }
+                        else
+                        {
+                            computeElecInteraction = (rC < rCoulomb && qq[i] != zero && bPairIncluded);
+                        }
+                        if (gmx::anyTrue(computeElecInteraction))
+                        {
+                            if constexpr (elecInteractionTypeIsEwald)
                             {
                                 vCoul[i]  = ewaldPotential(qq[i], rInvC, sh_ewald);
                                 fScalC[i] = ewaldScalarForce(qq[i], rInvC);
@@ -585,19 +711,30 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                                 vCoul[i]  = reactionFieldPotential(qq[i], rInvC, rC, krf, crf);
                                 fScalC[i] = reactionFieldScalarForce(qq[i], rInvC, rC, krf, two);
                             }
+
+                            vCoul[i]  = gmx::selectByMask(vCoul[i], computeElecInteraction);
+                            fScalC[i] = gmx::selectByMask(fScalC[i], computeElecInteraction);
                         }
 
-                        /* Only process the VDW interactions if we have
-                         * some non-zero parameters, and if we either
+                        /* Only process the VDW interactions if we either
                          * include all entries in the list (no cutoff used
                          * in the kernel), or if we are within the cutoff.
                          */
-                        bool computeVdwInteraction = (vdwInteractionTypeIsEwald && r < rVdw)
-                                                     || (!vdwInteractionTypeIsEwald && rV < rVdw);
-                        if ((c6[i] != 0 || c12[i] != 0) && computeVdwInteraction)
+                        BoolType computeVdwInteraction;
+                        if constexpr (vdwInteractionTypeIsEwald)
+                        {
+                            computeVdwInteraction =
+                                    (r < rVdw && (c6[i] != 0 || c12[i] != 0) && bPairIncluded);
+                        }
+                        else
+                        {
+                            computeVdwInteraction =
+                                    (rV < rVdw && (c6[i] != 0 || c12[i] != 0) && bPairIncluded);
+                        }
+                        if (gmx::anyTrue(computeVdwInteraction))
                         {
                             RealType rInv6;
-                            if (useSoftCore)
+                            if constexpr (useSoftCore)
                             {
                                 rInv6 = rPInvV;
                             }
@@ -612,26 +749,33 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                                     vVdw6, vVdw12, c6[i], c12[i], repulsionShift, dispersionShift, oneSixth, oneTwelfth);
                             fScalV[i] = lennardJonesScalarForce(vVdw6, vVdw12);
 
-                            if (vdwInteractionTypeIsEwald)
+                            if constexpr (vdwInteractionTypeIsEwald)
                             {
                                 /* Subtract the grid potential at the cut-off */
-                                vVdw[i] += ewaldLennardJonesGridSubtract(
-                                        nbfp_grid[tj[i]], shLjEwald, oneSixth);
+                                vVdw[i] = vVdw[i]
+                                          + gmx::selectByMask(ewaldLennardJonesGridSubtract(
+                                                                      ljPmeC6Grid[i], shLjEwald, oneSixth),
+                                                              computeVdwInteraction);
                             }
 
-                            if (vdwModifierIsPotSwitch)
+                            if constexpr (vdwModifierIsPotSwitch)
                             {
-                                RealType d        = rV - ic.rvdw_switch;
-                                d                 = (d > zero) ? d : zero;
-                                const RealType d2 = d * d;
+                                RealType d             = rV - rVdwSwitch;
+                                BoolType zeroMask      = zero < d;
+                                BoolType potSwitchMask = rV < rVdw;
+                                d                      = gmx::selectByMask(d, zeroMask);
+                                const RealType d2      = d * d;
                                 const RealType sw =
                                         one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5));
                                 const RealType dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4));
 
                                 fScalV[i] = potSwitchScalarForceMod(
-                                        fScalV[i], vVdw[i], sw, rV, rVdw, dsw, zero);
-                                vVdw[i] = potSwitchPotentialMod(vVdw[i], sw, rV, rVdw, zero);
+                                        fScalV[i], vVdw[i], sw, rV, dsw, potSwitchMask);
+                                vVdw[i] = potSwitchPotentialMod(vVdw[i], sw, potSwitchMask);
                             }
+
+                            vVdw[i]   = gmx::selectByMask(vVdw[i], computeVdwInteraction);
+                            fScalV[i] = gmx::selectByMask(fScalV[i], computeVdwInteraction);
                         }
 
                         /* fScalC (and fScalV) now contain: dV/drC * rC
@@ -639,57 +783,69 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                          * Further down we first multiply by r^p-2 and then by
                          * the vector r, which in total gives: dV/drC * (r/rC)^1-p
                          */
-                        fScalC[i] *= rPInvC;
-                        fScalV[i] *= rPInvV;
-                    }
-                } // end for (int i = 0; i < NSTATES; i++)
-
-                /* Assemble A and B states */
-                for (int i = 0; i < NSTATES; i++)
+                        fScalC[i] = fScalC[i] * rPInvC;
+                        fScalV[i] = fScalV[i] * rPInvV;
+                    } // end of block requiring nonZeroState
+                }     // end for (int i = 0; i < NSTATES; i++)
+
+                /* Assemble A and B states. */
+                BoolType assembleStates = (bPairIncluded && withinCutoffMask);
+                if (gmx::anyTrue(assembleStates))
                 {
-                    vCTot += LFC[i] * vCoul[i];
-                    vVTot += LFV[i] * vVdw[i];
+                    for (int i = 0; i < NSTATES; i++)
+                    {
+                        vCTot = vCTot + LFC[i] * vCoul[i];
+                        vVTot = vVTot + LFV[i] * vVdw[i];
 
-                    fScal += LFC[i] * fScalC[i] * rpm2;
-                    fScal += LFV[i] * fScalV[i] * rpm2;
+                        fScal = fScal + LFC[i] * fScalC[i] * rpm2;
+                        fScal = fScal + LFV[i] * fScalV[i] * rpm2;
 
-                    if (useSoftCore)
-                    {
-                        dvdlCoul += vCoul[i] * DLF[i]
-                                    + LFC[i] * alphaCoulEff * dlFacCoul[i] * fScalC[i] * sigma6[i];
-                        dvdlVdw += vVdw[i] * DLF[i]
-                                   + LFV[i] * alphaVdwEff * dlFacVdw[i] * fScalV[i] * sigma6[i];
-                    }
-                    else
-                    {
-                        dvdlCoul += vCoul[i] * DLF[i];
-                        dvdlVdw += vVdw[i] * DLF[i];
+                        if constexpr (useSoftCore)
+                        {
+                            dvdlCoul = dvdlCoul + vCoul[i] * DLF[i]
+                                       + LFC[i] * alphaCoulEff * dlFacCoul[i] * fScalC[i] * sigma6[i];
+                            dvdlVdw = dvdlVdw + vVdw[i] * DLF[i]
+                                      + LFV[i] * alphaVdwEff * dlFacVdw[i] * fScalV[i] * sigma6[i];
+                        }
+                        else
+                        {
+                            dvdlCoul = dvdlCoul + vCoul[i] * DLF[i];
+                            dvdlVdw  = dvdlVdw + vVdw[i] * DLF[i];
+                        }
                     }
                 }
-            } // end if (bPairIncluded)
-            else if (icoul == NbkernelElecType::ReactionField)
+            } // end of block requiring bPairIncluded && withinCutoffMask
+            /* In the following block bPairIncluded should be false in the masks. */
+            if (icoul == NbkernelElecType::ReactionField)
             {
-                /* For excluded pairs, which are only in this pair list when
-                 * using the Verlet scheme, we don't use soft-core.
-                 * As there is no singularity, there is no need for soft-core.
-                 */
-                const real FF = -two * krf;
-                RealType   VV = krf * rSq - crf;
+                const BoolType computeReactionField = bPairExcluded;
 
-                if (ii == jnr)
+                if (gmx::anyTrue(computeReactionField))
                 {
-                    VV *= half;
-                }
+                    /* For excluded pairs we don't use soft-core.
+                     * As there is no singularity, there is no need for soft-core.
+                     */
+                    const RealType FF = -two * krf;
+                    RealType       VV = krf * rSq - crf;
 
-                for (int i = 0; i < NSTATES; i++)
-                {
-                    vCTot += LFC[i] * qq[i] * VV;
-                    fScal += LFC[i] * qq[i] * FF;
-                    dvdlCoul += DLF[i] * qq[i] * VV;
+                    /* If ii == jnr the i particle (ii) has itself (jnr)
+                     * in its neighborlist. This corresponds to a self-interaction
+                     * that will occur twice. Scale it down by 50% to only include
+                     * it once.
+                     */
+                    VV = VV * gmx::blend(one, half, bIiEqJnr);
+
+                    for (int i = 0; i < NSTATES; i++)
+                    {
+                        vCTot = vCTot + gmx::selectByMask(LFC[i] * qq[i] * VV, computeReactionField);
+                        fScal = fScal + gmx::selectByMask(LFC[i] * qq[i] * FF, computeReactionField);
+                        dvdlCoul = dvdlCoul + gmx::selectByMask(DLF[i] * qq[i] * VV, computeReactionField);
+                    }
                 }
             }
 
-            if (elecInteractionTypeIsEwald && (r < rCoulomb || !bPairIncluded))
+            const BoolType computeElecEwaldInteraction = (bPairExcluded || r < rCoulomb);
+            if (elecInteractionTypeIsEwald && gmx::anyTrue(computeElecEwaldInteraction))
             {
                 /* See comment in the preamble. When using Ewald interactions
                  * (unless we use a switch modifier) we subtract the reciprocal-space
@@ -699,39 +855,33 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                  * the softcore to the entire electrostatic interaction,
                  * including the reciprocal-space component.
                  */
-                real v_lr, f_lr;
+                RealType v_lr, f_lr;
 
-                const RealType ewrt   = r * coulombTableScale;
-                IntType        ewitab = static_cast<IntType>(ewrt);
-                const RealType eweps  = ewrt - ewitab;
-                ewitab                = 4 * ewitab;
-                f_lr                  = ewtab[ewitab] + eweps * ewtab[ewitab + 1];
-                v_lr = (ewtab[ewitab + 2] - coulombTableScaleInvHalf * eweps * (ewtab[ewitab] + f_lr));
-                f_lr *= rInv;
+                pmeCoulombCorrectionVF(rSq, ewaldBeta, &v_lr, &f_lr, rSqValid);
+                f_lr = f_lr * rInv * rInv;
 
                 /* Note that any possible Ewald shift has already been applied in
                  * the normal interaction part above.
                  */
 
-                if (ii == jnr)
-                {
-                    /* If we get here, the i particle (ii) has itself (jnr)
-                     * in its neighborlist. This can only happen with the Verlet
-                     * scheme, and corresponds to a self-interaction that will
-                     * occur twice. Scale it down by 50% to only include it once.
-                     */
-                    v_lr *= half;
-                }
+                /* If ii == jnr the i particle (ii) has itself (jnr)
+                 * in its neighborlist. This corresponds to a self-interaction
+                 * that will occur twice. Scale it down by 50% to only include
+                 * it once.
+                 */
+                v_lr = v_lr * gmx::blend(one, half, bIiEqJnr);
 
                 for (int i = 0; i < NSTATES; i++)
                 {
-                    vCTot -= LFC[i] * qq[i] * v_lr;
-                    fScal -= LFC[i] * qq[i] * f_lr;
-                    dvdlCoul -= (DLF[i] * qq[i]) * v_lr;
+                    vCTot = vCTot - gmx::selectByMask(LFC[i] * qq[i] * v_lr, computeElecEwaldInteraction);
+                    fScal = fScal - gmx::selectByMask(LFC[i] * qq[i] * f_lr, computeElecEwaldInteraction);
+                    dvdlCoul = dvdlCoul
+                               - gmx::selectByMask(DLF[i] * qq[i] * v_lr, computeElecEwaldInteraction);
                 }
             }
 
-            if (vdwInteractionTypeIsEwald && (r < rVdw || !bPairIncluded))
+            const BoolType computeVdwEwaldInteraction = (bPairExcluded || r < rVdw);
+            if (vdwInteractionTypeIsEwald && gmx::anyTrue(computeVdwEwaldInteraction))
             {
                 /* See comment in the preamble. When using LJ-Ewald interactions
                  * (unless we use a switch modifier) we subtract the reciprocal-space
@@ -741,147 +891,105 @@ static void nb_free_energy_kernel(const t_nblist&                nlist,
                  * the softcore to the entire VdW interaction,
                  * including the reciprocal-space component.
                  */
-                /* We could also use the analytical form here
-                 * iso a table, but that can cause issues for
-                 * r close to 0 for non-interacting pairs.
-                 */
 
-                const RealType rs   = rSq * rInv * vdwTableScale;
-                const IntType  ri   = static_cast<IntType>(rs);
-                const RealType frac = rs - ri;
-                const RealType f_lr = (1 - frac) * tab_ewald_F_lj[ri] + frac * tab_ewald_F_lj[ri + 1];
-                /* TODO: Currently the Ewald LJ table does not contain
-                 * the factor 1/6, we should add this.
-                 */
-                const RealType FF = f_lr * rInv / six;
-                RealType       VV =
-                        (tab_ewald_V_lj[ri] - vdwTableScaleInvHalf * frac * (tab_ewald_F_lj[ri] + f_lr))
-                        / six;
-
-                if (ii == jnr)
-                {
-                    /* If we get here, the i particle (ii) has itself (jnr)
-                     * in its neighborlist. This can only happen with the Verlet
-                     * scheme, and corresponds to a self-interaction that will
-                     * occur twice. Scale it down by 50% to only include it once.
-                     */
-                    VV *= half;
-                }
+                RealType v_lr, f_lr;
+                pmeLJCorrectionVF(
+                        rInv, rSq, ewaldLJCoeffSq, ewaldLJCoeffSixDivSix, &v_lr, &f_lr, computeVdwEwaldInteraction, bIiEqJnr);
+                v_lr = v_lr * oneSixth;
 
                 for (int i = 0; i < NSTATES; i++)
                 {
-                    const real c6grid = nbfp_grid[tj[i]];
-                    vVTot += LFV[i] * c6grid * VV;
-                    fScal += LFV[i] * c6grid * FF;
-                    dvdlVdw += (DLF[i] * c6grid) * VV;
+                    vVTot = vVTot + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
+                    fScal = fScal + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * f_lr, computeVdwEwaldInteraction);
+                    dvdlVdw = dvdlVdw + gmx::selectByMask(DLF[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
                 }
             }
 
-            if (doForces)
+            if (doForces && gmx::anyTrue(fScal != zero))
             {
-                const real tX = fScal * dX;
-                const real tY = fScal * dY;
-                const real tZ = fScal * dZ;
-                fIX           = fIX + tX;
-                fIY           = fIY + tY;
-                fIZ           = fIZ + tZ;
-                /* OpenMP atomics are expensive, but this kernels is also
-                 * expensive, so we can take this hit, instead of using
-                 * thread-local output buffers and extra reduction.
-                 *
-                 * All the OpenMP regions in this file are trivial and should
-                 * not throw, so no need for try/catch.
-                 */
-#pragma omp atomic
-                f[j3] -= tX;
-#pragma omp atomic
-                f[j3 + 1] -= tY;
-#pragma omp atomic
-                f[j3 + 2] -= tZ;
+                const RealType tX = fScal * dX;
+                const RealType tY = fScal * dY;
+                const RealType tZ = fScal * dZ;
+                fIX               = fIX + tX;
+                fIY               = fIY + tY;
+                fIZ               = fIZ + tZ;
+
+                gmx::transposeScatterDecrU<3>(
+                        reinterpret_cast<real*>(threadForceBuffer), preloadJnr, tX, tY, tZ);
             }
-        } // end for (int k = nj0; k < nj1; k++)
-
-        /* The atomics below are expensive with many OpenMP threads.
-         * Here unperturbed i-particles will usually only have a few
-         * (perturbed) j-particles in the list. Thus with a buffered list
-         * we can skip a significant number of i-reductions with a check.
-         */
-        if (npair_within_cutoff > 0)
+        } // end for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
+
+        if (havePairsWithinCutoff)
         {
             if (doForces)
             {
-#pragma omp atomic
-                f[ii3] += fIX;
-#pragma omp atomic
-                f[ii3 + 1] += fIY;
-#pragma omp atomic
-                f[ii3 + 2] += fIZ;
+                gmx::transposeScatterIncrU<3>(
+                        reinterpret_cast<real*>(threadForceBuffer), preloadIi, fIX, fIY, fIZ);
             }
             if (doShiftForces)
             {
-#pragma omp atomic
-                fshift[is3] += fIX;
-#pragma omp atomic
-                fshift[is3 + 1] += fIY;
-#pragma omp atomic
-                fshift[is3 + 2] += fIZ;
+                gmx::transposeScatterIncrU<3>(
+                        reinterpret_cast<real*>(threadForceShiftBuffer), preloadIs, fIX, fIY, fIZ);
             }
             if (doPotential)
             {
                 int ggid = gid[n];
-#pragma omp atomic
-                energygrp_elec[ggid] += vCTot;
-#pragma omp atomic
-                energygrp_vdw[ggid] += vVTot;
+                threadVc[ggid] += gmx::reduce(vCTot);
+                threadVv[ggid] += gmx::reduce(vVTot);
             }
         }
     } // end for (int n = 0; n < nri; n++)
 
-#pragma omp atomic
-    dvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)] += dvdlCoul;
-#pragma omp atomic
-    dvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)] += dvdlVdw;
+    if (gmx::anyTrue(dvdlCoul != zero))
+    {
+        threadDvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)] += gmx::reduce(dvdlCoul);
+    }
+    if (gmx::anyTrue(dvdlVdw != zero))
+    {
+        threadDvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)] += gmx::reduce(dvdlVdw);
+    }
 
     /* Estimate flops, average for free energy stuff:
      * 12  flops per outer iteration
      * 150 flops per inner iteration
+     * TODO: Update the number of flops and/or use different counts for different code paths.
      */
     atomicNrnbIncrement(nrnb, eNR_NBKERNEL_FREE_ENERGY, nlist.nri * 12 + nlist.jindex[nri] * 150);
 
-    if (numExcludedPairsBeyondRlist > 0)
+    if (haveExcludedPairsBeyondRlist > 0)
     {
         gmx_fatal(FARGS,
-                  "There are %d perturbed non-bonded pair interactions beyond the pair-list cutoff "
+                  "There are perturbed non-bonded pair interactions beyond the pair-list cutoff "
                   "of %g nm, which is not supported. This can happen because the system is "
                   "unstable or because intra-molecular interactions at long distances are "
                   "excluded. If the "
                   "latter is the case, you can try to increase nstlist or rlist to avoid this."
                   "The error is likely triggered by the use of couple-intramol=no "
                   "and the maximal distance in the decoupled molecule exceeding rlist.",
-                  numExcludedPairsBeyondRlist,
                   rlist);
     }
 }
 
-typedef void (*KernelFunction)(const t_nblist&                nlist,
-                               gmx::ArrayRef<const gmx::RVec> coords,
-                               gmx::ForceWithShiftForces*     forceWithShiftForces,
-                               const int                      ntype,
-                               const real                     rlist,
-                               const interaction_const_t&     ic,
-                               gmx::ArrayRef<const gmx::RVec> shiftvec,
-                               gmx::ArrayRef<const real>      nbfp,
-                               gmx::ArrayRef<const real>      nbfp_grid,
-                               gmx::ArrayRef<const real>      chargeA,
-                               gmx::ArrayRef<const real>      chargeB,
-                               gmx::ArrayRef<const int>       typeA,
-                               gmx::ArrayRef<const int>       typeB,
-                               int                            flags,
-                               gmx::ArrayRef<const real>      lambda,
-                               gmx::ArrayRef<real>            dvdl,
-                               gmx::ArrayRef<real>            energygrp_elec,
-                               gmx::ArrayRef<real>            energygrp_vdw,
-                               t_nrnb* gmx_restrict           nrnb);
+typedef void (*KernelFunction)(const t_nblist&                           nlist,
+                               const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                               const int                                 ntype,
+                               const real                                rlist,
+                               const interaction_const_t&                ic,
+                               gmx::ArrayRef<const gmx::RVec>            shiftvec,
+                               gmx::ArrayRef<const real>                 nbfp,
+                               gmx::ArrayRef<const real>                 nbfp_grid,
+                               gmx::ArrayRef<const real>                 chargeA,
+                               gmx::ArrayRef<const real>                 chargeB,
+                               gmx::ArrayRef<const int>                  typeA,
+                               gmx::ArrayRef<const int>                  typeB,
+                               int                                       flags,
+                               gmx::ArrayRef<const real>                 lambda,
+                               t_nrnb* gmx_restrict                      nrnb,
+                               gmx::RVec*                                threadForceBuffer,
+                               rvec*                                     threadForceShiftBuffer,
+                               gmx::ArrayRef<real>                       threadVc,
+                               gmx::ArrayRef<real>                       threadVv,
+                               gmx::ArrayRef<real>                       threadDvdl);
 
 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
 static KernelFunction dispatchKernelOnUseSimd(const bool useSimd)
@@ -889,8 +997,7 @@ static KernelFunction dispatchKernelOnUseSimd(const bool useSimd)
     if (useSimd)
     {
 #if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS && GMX_USE_SIMD_KERNELS
-        /* FIXME: Here SimdDataTypes should be used to enable SIMD. So far, the code in nb_free_energy_kernel is not adapted to SIMD */
-        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
+        return (nb_free_energy_kernel<SimdDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
 #else
         return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
 #endif
@@ -996,26 +1103,27 @@ static KernelFunction dispatchKernel(const bool                 scLambdasOrAlpha
 }
 
 
-void gmx_nb_free_energy_kernel(const t_nblist&                nlist,
-                               gmx::ArrayRef<const gmx::RVec> coords,
-                               gmx::ForceWithShiftForces*     ff,
-                               const bool                     useSimd,
-                               const int                      ntype,
-                               const real                     rlist,
-                               const interaction_const_t&     ic,
-                               gmx::ArrayRef<const gmx::RVec> shiftvec,
-                               gmx::ArrayRef<const real>      nbfp,
-                               gmx::ArrayRef<const real>      nbfp_grid,
-                               gmx::ArrayRef<const real>      chargeA,
-                               gmx::ArrayRef<const real>      chargeB,
-                               gmx::ArrayRef<const int>       typeA,
-                               gmx::ArrayRef<const int>       typeB,
-                               int                            flags,
-                               gmx::ArrayRef<const real>      lambda,
-                               gmx::ArrayRef<real>            dvdl,
-                               gmx::ArrayRef<real>            energygrp_elec,
-                               gmx::ArrayRef<real>            energygrp_vdw,
-                               t_nrnb*                        nrnb)
+void gmx_nb_free_energy_kernel(const t_nblist&                           nlist,
+                               const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                               const bool                                useSimd,
+                               const int                                 ntype,
+                               const real                                rlist,
+                               const interaction_const_t&                ic,
+                               gmx::ArrayRef<const gmx::RVec>            shiftvec,
+                               gmx::ArrayRef<const real>                 nbfp,
+                               gmx::ArrayRef<const real>                 nbfp_grid,
+                               gmx::ArrayRef<const real>                 chargeA,
+                               gmx::ArrayRef<const real>                 chargeB,
+                               gmx::ArrayRef<const int>                  typeA,
+                               gmx::ArrayRef<const int>                  typeB,
+                               int                                       flags,
+                               gmx::ArrayRef<const real>                 lambda,
+                               t_nrnb*                                   nrnb,
+                               gmx::RVec*                                threadForceBuffer,
+                               rvec*                                     threadForceShiftBuffer,
+                               gmx::ArrayRef<real>                       threadVc,
+                               gmx::ArrayRef<real>                       threadVv,
+                               gmx::ArrayRef<real>                       threadDvdl)
 {
     GMX_ASSERT(EEL_PME_EWALD(ic.eeltype) || ic.eeltype == CoulombInteractionType::Cut || EEL_RF(ic.eeltype),
                "Unsupported eeltype with free energy");
@@ -1050,7 +1158,6 @@ void gmx_nb_free_energy_kernel(const t_nblist&                nlist,
                                 ic);
     kernelFunc(nlist,
                coords,
-               ff,
                ntype,
                rlist,
                ic,
@@ -1063,8 +1170,10 @@ void gmx_nb_free_energy_kernel(const t_nblist&                nlist,
                typeB,
                flags,
                lambda,
-               dvdl,
-               energygrp_elec,
-               energygrp_vdw,
-               nrnb);
+               nrnb,
+               threadForceBuffer,
+               threadForceShiftBuffer,
+               threadVc,
+               threadVv,
+               threadDvdl);
 }
index 3a04f196cb8699f2e1151950b80aef1f9787fe3e..d5c3188067a39045a1e145a98b69cd2571d6c0db 100644 (file)
@@ -47,30 +47,32 @@ struct t_nblist;
 struct interaction_const_t;
 namespace gmx
 {
-class ForceWithShiftForces;
 template<typename>
 class ArrayRef;
+template<typename>
+class ArrayRefWithPadding;
 } // namespace gmx
 
-void gmx_nb_free_energy_kernel(const t_nblist&                nlist,
-                               gmx::ArrayRef<const gmx::RVec> coords,
-                               gmx::ForceWithShiftForces*     forceWithShiftForces,
-                               bool                           useSimd,
-                               int                            ntype,
-                               real                           rlist,
-                               const interaction_const_t&     ic,
-                               gmx::ArrayRef<const gmx::RVec> shiftvec,
-                               gmx::ArrayRef<const real>      nbfp,
-                               gmx::ArrayRef<const real>      nbfp_grid,
-                               gmx::ArrayRef<const real>      chargeA,
-                               gmx::ArrayRef<const real>      chargeB,
-                               gmx::ArrayRef<const int>       typeA,
-                               gmx::ArrayRef<const int>       typeB,
-                               int                            flags,
-                               gmx::ArrayRef<const real>      lambda,
-                               gmx::ArrayRef<real>            dvdl,
-                               gmx::ArrayRef<real>            energygrp_elec,
-                               gmx::ArrayRef<real>            energygrp_vdw,
-                               t_nrnb* gmx_restrict           nrnb);
+void gmx_nb_free_energy_kernel(const t_nblist&                           nlist,
+                               const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                               bool                                      useSimd,
+                               int                                       ntype,
+                               real                                      rlist,
+                               const interaction_const_t&                ic,
+                               gmx::ArrayRef<const gmx::RVec>            shiftvec,
+                               gmx::ArrayRef<const real>                 nbfp,
+                               gmx::ArrayRef<const real>                 nbfp_grid,
+                               gmx::ArrayRef<const real>                 chargeA,
+                               gmx::ArrayRef<const real>                 chargeB,
+                               gmx::ArrayRef<const int>                  typeA,
+                               gmx::ArrayRef<const int>                  typeB,
+                               int                                       flags,
+                               gmx::ArrayRef<const real>                 lambda,
+                               t_nrnb* gmx_restrict                      nrnb,
+                               gmx::RVec*                                threadForceBuffer,
+                               rvec*                                     threadForceShiftBuffer,
+                               gmx::ArrayRef<real>                       threadVc,
+                               gmx::ArrayRef<real>                       threadVv,
+                               gmx::ArrayRef<real>                       threadDvdl);
 
 #endif
index 3ffc4d9ab131f5d5b4a1a51c3994dd4ed09b9d05..d25b98c1ddb703ee66890dfa03ee992eb56aa633 100644 (file)
@@ -416,7 +416,7 @@ protected:
         // When the free-energy kernel switches from tabulated to analytical corrections,
         // the double precision tolerance can be tightend to 1e-11.
         test::FloatingPointTolerance tolerance(
-                input_.floatToler, input_.doubleToler, 1.0e-6, 1.0e-6, 10000, 100, false);
+                input_.floatToler, input_.doubleToler, 1.0e-6, 1.0e-11, 10000, 100, false);
         checker_.setDefaultTolerance(tolerance);
     }
 
@@ -454,8 +454,7 @@ protected:
 
         // run fep kernel
         gmx_nb_free_energy_kernel(nbl,
-                                  x_.arrayRefWithPadding().unpaddedArrayRef(),
-                                  &forces,
+                                  x_.arrayRefWithPadding(),
                                   fr.use_simd_kernels,
                                   fr.ntype,
                                   fr.rlist,
@@ -469,10 +468,12 @@ protected:
                                   input_.atoms.typeB,
                                   doNBFlags,
                                   lambdas,
-                                  output.dvdLambda,
+                                  &nrnb,
+                                  output.f.arrayRefWithPadding().paddedArrayRef().data(),
+                                  as_rvec_array(output.fShift.data()),
                                   output.energy.energyGroupPairTerms[NonBondedEnergyTerms::CoulombSR],
                                   output.energy.energyGroupPairTerms[NonBondedEnergyTerms::LJSR],
-                                  &nrnb);
+                                  output.dvdLambda);
 
         checkOutput(&checker_, output);
     }
index 13a13d23eac9f6b8921ee152fcbc57e162c148f5..1d703a72458e68081e9883c6e8e57222aa1da602 100644 (file)
@@ -1634,6 +1634,15 @@ void do_force(FILE*                               fplog,
         }
     }
 
+    // With FEP we set up the reduction over threads for local+non-local simultaneously,
+    // so we need to do that here after the local and non-local pairlist construction.
+    if (stepWork.doNeighborSearch && fr->efep != FreeEnergyPerturbationType::No)
+    {
+        wallcycle_sub_start(wcycle, WallCycleSubCounter::NonbondedFep);
+        nbv->setupFepThreadedForceBuffer(fr->natoms_force_constr);
+        wallcycle_sub_stop(wcycle, WallCycleSubCounter::NonbondedFep);
+    }
+
     if (simulationWork.useGpuNonbonded && stepWork.computeNonbondedForces)
     {
         /* launch D2H copy-back F */
@@ -1776,9 +1785,8 @@ void do_force(FILE*                               fplog,
         /* Calculate the local and non-local free energy interactions here.
          * Happens here on the CPU both with and without GPU.
          */
-        nbv->dispatchFreeEnergyKernel(
-                InteractionLocality::Local,
-                x.unpaddedArrayRef(),
+        nbv->dispatchFreeEnergyKernels(
+                x,
                 &forceOutNonbonded->forceWithShiftForces(),
                 fr->use_simd_kernels,
                 fr->ntype,
@@ -1800,34 +1808,6 @@ void do_force(FILE*                               fplog,
                 enerd,
                 stepWork,
                 nrnb);
-
-        if (simulationWork.havePpDomainDecomposition)
-        {
-            nbv->dispatchFreeEnergyKernel(
-                    InteractionLocality::NonLocal,
-                    x.unpaddedArrayRef(),
-                    &forceOutNonbonded->forceWithShiftForces(),
-                    fr->use_simd_kernels,
-                    fr->ntype,
-                    fr->rlist,
-                    *fr->ic,
-                    fr->shift_vec,
-                    fr->nbfp,
-                    fr->ljpme_c6grid,
-                    mdatoms->chargeA ? gmx::arrayRefFromArray(mdatoms->chargeA, mdatoms->nr)
-                                     : gmx::ArrayRef<real>{},
-                    mdatoms->chargeB ? gmx::arrayRefFromArray(mdatoms->chargeB, mdatoms->nr)
-                                     : gmx::ArrayRef<real>{},
-                    mdatoms->typeA ? gmx::arrayRefFromArray(mdatoms->typeA, mdatoms->nr)
-                                   : gmx::ArrayRef<int>{},
-                    mdatoms->typeB ? gmx::arrayRefFromArray(mdatoms->typeB, mdatoms->nr)
-                                   : gmx::ArrayRef<int>{},
-                    inputrec.fepvals.get(),
-                    lambda,
-                    enerd,
-                    stepWork,
-                    nrnb);
-        }
     }
 
     if (stepWork.computeNonbondedForces && !useOrEmulateGpuNb)
index d95b2c5772e26a6df39b6b5098e6c23b3127a5d4..ca5d09280278593066f79536af180bc2808a34bc 100644 (file)
@@ -38,9 +38,10 @@ add_library(nbnxm INTERFACE)
 add_subdirectory(kernels_simd_4xm)
 add_subdirectory(kernels_simd_2xmm)
 
-file (GLOB NBNXM_SOURCES
+file(GLOB NBNXM_SOURCES
     # Source files
     atomdata.cpp
+    freeenergydispatch.cpp
     grid.cpp
     gridset.cpp
     kernel_common.cpp
diff --git a/src/gromacs/nbnxm/freeenergydispatch.cpp b/src/gromacs/nbnxm/freeenergydispatch.cpp
new file mode 100644 (file)
index 0000000..826bd9f
--- /dev/null
@@ -0,0 +1,462 @@
+/*
+ * This file is part of the GROMACS molecular simulation package.
+ *
+ * Copyright (c) 2021, by the GROMACS development team, led by
+ * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
+ * and including many others, as listed in the AUTHORS file in the
+ * top-level source directory and at http://www.gromacs.org.
+ *
+ * GROMACS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1
+ * of the License, or (at your option) any later version.
+ *
+ * GROMACS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with GROMACS; if not, see
+ * http://www.gnu.org/licenses, or write to the Free Software Foundation,
+ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
+ *
+ * If you want to redistribute modifications to GROMACS, please
+ * consider that scientific software is very special. Version
+ * control is crucial - bugs must be traceable. We will be happy to
+ * consider code for inclusion in the official distribution, but
+ * derived work must not be called official GROMACS. Details are found
+ * in the README & COPYING files - if they are missing, get the
+ * official version at http://www.gromacs.org.
+ *
+ * To help us fund GROMACS development, we humbly ask that you cite
+ * the research papers on the package. Check out http://www.gromacs.org.
+ */
+
+#include "gmxpre.h"
+
+#include "freeenergydispatch.h"
+
+#include "gromacs/gmxlib/nrnb.h"
+#include "gromacs/gmxlib/nonbonded/nb_free_energy.h"
+#include "gromacs/gmxlib/nonbonded/nonbonded.h"
+#include "gromacs/math/vectypes.h"
+#include "gromacs/mdlib/enerdata_utils.h"
+#include "gromacs/mdlib/force.h"
+#include "gromacs/mdlib/gmx_omp_nthreads.h"
+#include "gromacs/mdtypes/enerdata.h"
+#include "gromacs/mdtypes/forceoutput.h"
+#include "gromacs/mdtypes/inputrec.h"
+#include "gromacs/mdtypes/interaction_const.h"
+#include "gromacs/mdtypes/md_enums.h"
+#include "gromacs/mdtypes/nblist.h"
+#include "gromacs/mdtypes/simulation_workload.h"
+#include "gromacs/mdtypes/threaded_force_buffer.h"
+#include "gromacs/nbnxm/nbnxm.h"
+#include "gromacs/timing/wallcycle.h"
+#include "gromacs/utility/enumerationhelpers.h"
+#include "gromacs/utility/gmxassert.h"
+#include "gromacs/utility/real.h"
+
+#include "pairlistset.h"
+#include "pairlistsets.h"
+
+FreeEnergyDispatch::FreeEnergyDispatch(const int numEnergyGroups) :
+    foreignGroupPairEnergies_(numEnergyGroups),
+    threadedForceBuffer_(gmx_omp_nthreads_get(ModuleMultiThread::Nonbonded), false, numEnergyGroups),
+    threadedForeignEnergyBuffer_(gmx_omp_nthreads_get(ModuleMultiThread::Nonbonded), false, numEnergyGroups)
+{
+}
+
+namespace
+{
+
+//! Flags all atoms present in pairlist \p nlist in the mask in \p threadForceBuffer
+void setReductionMaskFromFepPairlist(const t_nblist& gmx_restrict       nlist,
+                                     gmx::ThreadForceBuffer<gmx::RVec>* threadForceBuffer)
+{
+    // Extract pair list data
+    gmx::ArrayRef<const int> iinr = nlist.iinr;
+    gmx::ArrayRef<const int> jjnr = nlist.jjnr;
+
+    for (int i : iinr)
+    {
+        threadForceBuffer->addAtomToMask(i);
+    }
+    for (int j : jjnr)
+    {
+        threadForceBuffer->addAtomToMask(j);
+    }
+}
+
+} // namespace
+
+void FreeEnergyDispatch::setupFepThreadedForceBuffer(const int numAtomsForce, const PairlistSets& pairlistSets)
+{
+    const int numThreads = threadedForceBuffer_.numThreadBuffers();
+
+    GMX_ASSERT(gmx_omp_nthreads_get(ModuleMultiThread::Nonbonded) == numThreads,
+               "The number of buffers should be same as number of NB threads");
+
+#pragma omp parallel for num_threads(numThreads) schedule(static)
+    for (int th = 0; th < numThreads; th++)
+    {
+        auto& threadForceBuffer = threadedForceBuffer_.threadForceBuffer(th);
+
+        threadForceBuffer.resizeBufferAndClearMask(numAtomsForce);
+
+        setReductionMaskFromFepPairlist(
+                *pairlistSets.pairlistSet(gmx::InteractionLocality::Local).fepLists()[th],
+                &threadForceBuffer);
+        if (pairlistSets.params().haveMultipleDomains)
+        {
+            setReductionMaskFromFepPairlist(
+                    *pairlistSets.pairlistSet(gmx::InteractionLocality::NonLocal).fepLists()[th],
+                    &threadForceBuffer);
+        }
+
+        threadForceBuffer.processMask();
+    }
+
+    threadedForceBuffer_.setupReduction();
+}
+
+void nonbonded_verlet_t::setupFepThreadedForceBuffer(const int numAtomsForce)
+{
+    if (!pairlistSets_->params().haveFep)
+    {
+        return;
+    }
+
+    GMX_RELEASE_ASSERT(freeEnergyDispatch_, "Need a valid dispatch object");
+
+    freeEnergyDispatch_->setupFepThreadedForceBuffer(numAtomsForce, *pairlistSets_);
+}
+
+namespace
+{
+
+void dispatchFreeEnergyKernel(gmx::ArrayRef<const std::unique_ptr<t_nblist>>   nbl_fep,
+                              const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                              bool                                             useSimd,
+                              int                                              ntype,
+                              real                                             rlist,
+                              const interaction_const_t&                       ic,
+                              gmx::ArrayRef<const gmx::RVec>                   shiftvec,
+                              gmx::ArrayRef<const real>                        nbfp,
+                              gmx::ArrayRef<const real>                        nbfp_grid,
+                              gmx::ArrayRef<const real>                        chargeA,
+                              gmx::ArrayRef<const real>                        chargeB,
+                              gmx::ArrayRef<const int>                         typeA,
+                              gmx::ArrayRef<const int>                         typeB,
+                              t_lambda*                                        fepvals,
+                              gmx::ArrayRef<const real>                        lambda,
+                              const bool                           clearForcesAndEnergies,
+                              gmx::ThreadedForceBuffer<gmx::RVec>* threadedForceBuffer,
+                              gmx::ThreadedForceBuffer<gmx::RVec>* threadedForeignEnergyBuffer,
+                              gmx_grppairener_t*                   foreignGroupPairEnergies,
+                              gmx_enerdata_t*                      enerd,
+                              const gmx::StepWorkload&             stepWork,
+                              t_nrnb*                              nrnb)
+{
+    int donb_flags = 0;
+    /* Add short-range interactions */
+    donb_flags |= GMX_NONBONDED_DO_SR;
+
+    if (stepWork.computeForces)
+    {
+        donb_flags |= GMX_NONBONDED_DO_FORCE;
+    }
+    if (stepWork.computeVirial)
+    {
+        donb_flags |= GMX_NONBONDED_DO_SHIFTFORCE;
+    }
+    if (stepWork.computeEnergy)
+    {
+        donb_flags |= GMX_NONBONDED_DO_POTENTIAL;
+    }
+
+    GMX_ASSERT(gmx_omp_nthreads_get(ModuleMultiThread::Nonbonded) == nbl_fep.ssize(),
+               "Number of lists should be same as number of NB threads");
+
+#pragma omp parallel for schedule(static) num_threads(nbl_fep.ssize())
+    for (gmx::index th = 0; th < nbl_fep.ssize(); th++)
+    {
+        try
+        {
+            auto& threadForceBuffer = threadedForceBuffer->threadForceBuffer(th);
+
+            if (clearForcesAndEnergies)
+            {
+                threadForceBuffer.clearForcesAndEnergies();
+            }
+
+            gmx::RVec* threadForces      = threadForceBuffer.forceBuffer();
+            rvec* threadForceShiftBuffer = as_rvec_array(threadForceBuffer.shiftForces().data());
+            gmx::ArrayRef<real> threadVc =
+                    threadForceBuffer.groupPairEnergies().energyGroupPairTerms[NonBondedEnergyTerms::CoulombSR];
+            gmx::ArrayRef<real> threadVv =
+                    threadForceBuffer.groupPairEnergies().energyGroupPairTerms[NonBondedEnergyTerms::LJSR];
+            gmx::ArrayRef<real> threadDvdl = threadForceBuffer.dvdl();
+
+            gmx_nb_free_energy_kernel(*nbl_fep[th],
+                                      coords,
+                                      useSimd,
+                                      ntype,
+                                      rlist,
+                                      ic,
+                                      shiftvec,
+                                      nbfp,
+                                      nbfp_grid,
+                                      chargeA,
+                                      chargeB,
+                                      typeA,
+                                      typeB,
+                                      donb_flags,
+                                      lambda,
+                                      nrnb,
+                                      threadForces,
+                                      threadForceShiftBuffer,
+                                      threadVc,
+                                      threadVv,
+                                      threadDvdl);
+        }
+        GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
+    }
+
+    /* If we do foreign lambda and we have soft-core interactions
+     * we have to recalculate the (non-linear) energies contributions.
+     */
+    if (fepvals->n_lambda > 0 && stepWork.computeDhdl && fepvals->sc_alpha != 0)
+    {
+        gmx::StepWorkload stepWorkForeignEnergies = stepWork;
+        stepWorkForeignEnergies.computeForces     = false;
+        stepWorkForeignEnergies.computeVirial     = false;
+
+        gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real> lam_i;
+        gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real> dvdl_nb = { 0 };
+        const int kernelFlags = (donb_flags & ~(GMX_NONBONDED_DO_FORCE | GMX_NONBONDED_DO_SHIFTFORCE))
+                                | GMX_NONBONDED_DO_FOREIGNLAMBDA;
+
+        for (gmx::index i = 0; i < 1 + enerd->foreignLambdaTerms.numLambdas(); i++)
+        {
+            std::fill(std::begin(dvdl_nb), std::end(dvdl_nb), 0);
+            for (int j = 0; j < static_cast<int>(FreeEnergyPerturbationCouplingType::Count); j++)
+            {
+                lam_i[j] = (i == 0 ? lambda[j] : fepvals->all_lambda[j][i - 1]);
+            }
+
+#pragma omp parallel for schedule(static) num_threads(nbl_fep.ssize())
+            for (gmx::index th = 0; th < nbl_fep.ssize(); th++)
+            {
+                try
+                {
+                    // Note that here we only compute energies and dV/dlambda, but we need
+                    // to pass a force buffer. No forces are compute and stored.
+                    auto& threadForeignEnergyBuffer = threadedForeignEnergyBuffer->threadForceBuffer(th);
+
+                    threadForeignEnergyBuffer.clearForcesAndEnergies();
+
+                    gmx::ArrayRef<real> threadVc =
+                            threadForeignEnergyBuffer.groupPairEnergies()
+                                    .energyGroupPairTerms[NonBondedEnergyTerms::CoulombSR];
+                    gmx::ArrayRef<real> threadVv =
+                            threadForeignEnergyBuffer.groupPairEnergies()
+                                    .energyGroupPairTerms[NonBondedEnergyTerms::LJSR];
+                    gmx::ArrayRef<real> threadDvdl = threadForeignEnergyBuffer.dvdl();
+
+                    gmx_nb_free_energy_kernel(*nbl_fep[th],
+                                              coords,
+                                              useSimd,
+                                              ntype,
+                                              rlist,
+                                              ic,
+                                              shiftvec,
+                                              nbfp,
+                                              nbfp_grid,
+                                              chargeA,
+                                              chargeB,
+                                              typeA,
+                                              typeB,
+                                              kernelFlags,
+                                              lam_i,
+                                              nrnb,
+                                              nullptr,
+                                              nullptr,
+                                              threadVc,
+                                              threadVv,
+                                              threadDvdl);
+                }
+                GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
+            }
+
+            foreignGroupPairEnergies->clear();
+            threadedForeignEnergyBuffer->reduce(
+                    nullptr, nullptr, foreignGroupPairEnergies, dvdl_nb, stepWorkForeignEnergies, 0);
+
+            std::array<real, F_NRE> foreign_term = { 0 };
+            sum_epot(*foreignGroupPairEnergies, foreign_term.data());
+            // Accumulate the foreign energy difference and dV/dlambda into the passed enerd
+            enerd->foreignLambdaTerms.accumulate(
+                    i,
+                    foreign_term[F_EPOT],
+                    dvdl_nb[FreeEnergyPerturbationCouplingType::Vdw]
+                            + dvdl_nb[FreeEnergyPerturbationCouplingType::Coul]);
+        }
+    }
+}
+
+} // namespace
+
+void FreeEnergyDispatch::dispatchFreeEnergyKernels(const PairlistSets& pairlistSets,
+                                                   const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                                                   gmx::ForceWithShiftForces* forceWithShiftForces,
+                                                   const bool                 useSimd,
+                                                   const int                  ntype,
+                                                   const real                 rlist,
+                                                   const interaction_const_t& ic,
+                                                   gmx::ArrayRef<const gmx::RVec> shiftvec,
+                                                   gmx::ArrayRef<const real>      nbfp,
+                                                   gmx::ArrayRef<const real>      nbfp_grid,
+                                                   gmx::ArrayRef<const real>      chargeA,
+                                                   gmx::ArrayRef<const real>      chargeB,
+                                                   gmx::ArrayRef<const int>       typeA,
+                                                   gmx::ArrayRef<const int>       typeB,
+                                                   t_lambda*                      fepvals,
+                                                   gmx::ArrayRef<const real>      lambda,
+                                                   gmx_enerdata_t*                enerd,
+                                                   const gmx::StepWorkload&       stepWork,
+                                                   t_nrnb*                        nrnb,
+                                                   gmx_wallcycle*                 wcycle)
+{
+    GMX_ASSERT(pairlistSets.params().haveFep, "We should have a free-energy pairlist");
+
+    wallcycle_sub_start(wcycle, WallCycleSubCounter::NonbondedFep);
+
+    const int numLocalities = (pairlistSets.params().haveMultipleDomains ? 2 : 1);
+    // The first call to dispatchFreeEnergyKernel() should clear the buffers. Clearing happens
+    // inside that function to avoid an extra OpenMP parallel region here. We need a boolean
+    // to track the need for clearing.
+    // A better solution would be to move the OpenMP parallel region here, but that first
+    // requires modifying ThreadedForceBuffer.reduce() to be called thread parallel.
+    bool clearForcesAndEnergies = true;
+    for (int i = 0; i < numLocalities; i++)
+    {
+        const gmx::InteractionLocality iLocality = static_cast<gmx::InteractionLocality>(i);
+        const auto fepPairlists                  = pairlistSets.pairlistSet(iLocality).fepLists();
+        /* When the first list is empty, all are empty and there is nothing to do */
+        if (fepPairlists[0]->nrj > 0)
+        {
+            dispatchFreeEnergyKernel(fepPairlists,
+                                     coords,
+                                     useSimd,
+                                     ntype,
+                                     rlist,
+                                     ic,
+                                     shiftvec,
+                                     nbfp,
+                                     nbfp_grid,
+                                     chargeA,
+                                     chargeB,
+                                     typeA,
+                                     typeB,
+                                     fepvals,
+                                     lambda,
+                                     clearForcesAndEnergies,
+                                     &threadedForceBuffer_,
+                                     &threadedForeignEnergyBuffer_,
+                                     &foreignGroupPairEnergies_,
+                                     enerd,
+                                     stepWork,
+                                     nrnb);
+        }
+        else if (clearForcesAndEnergies)
+        {
+            // We need to clear the thread force buffer.
+            // With a non-empty pairlist we do this in dispatchFreeEnergyKernel()
+            // to avoid the overhead of an extra openMP parallel loop
+#pragma omp parallel for schedule(static) num_threads(fepPairlists.ssize())
+            for (gmx::index th = 0; th < fepPairlists.ssize(); th++)
+            {
+                try
+                {
+                    threadedForceBuffer_.threadForceBuffer(th).clearForcesAndEnergies();
+                }
+                GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
+            }
+        }
+        clearForcesAndEnergies = false;
+    }
+    wallcycle_sub_stop(wcycle, WallCycleSubCounter::NonbondedFep);
+
+    wallcycle_sub_start(wcycle, WallCycleSubCounter::NonbondedFepReduction);
+
+    gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real> dvdl_nb = { 0 };
+
+    threadedForceBuffer_.reduce(forceWithShiftForces, nullptr, &enerd->grpp, dvdl_nb, stepWork, 0);
+
+    if (fepvals->sc_alpha != 0)
+    {
+        enerd->dvdl_nonlin[FreeEnergyPerturbationCouplingType::Vdw] +=
+                dvdl_nb[FreeEnergyPerturbationCouplingType::Vdw];
+        enerd->dvdl_nonlin[FreeEnergyPerturbationCouplingType::Coul] +=
+                dvdl_nb[FreeEnergyPerturbationCouplingType::Coul];
+    }
+    else
+    {
+        enerd->dvdl_lin[FreeEnergyPerturbationCouplingType::Vdw] +=
+                dvdl_nb[FreeEnergyPerturbationCouplingType::Vdw];
+        enerd->dvdl_lin[FreeEnergyPerturbationCouplingType::Coul] +=
+                dvdl_nb[FreeEnergyPerturbationCouplingType::Coul];
+    }
+
+    wallcycle_sub_stop(wcycle, WallCycleSubCounter::NonbondedFepReduction);
+}
+
+void nonbonded_verlet_t::dispatchFreeEnergyKernels(const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                                                   gmx::ForceWithShiftForces* forceWithShiftForces,
+                                                   const bool                 useSimd,
+                                                   const int                  ntype,
+                                                   const real                 rlist,
+                                                   const interaction_const_t& ic,
+                                                   gmx::ArrayRef<const gmx::RVec> shiftvec,
+                                                   gmx::ArrayRef<const real>      nbfp,
+                                                   gmx::ArrayRef<const real>      nbfp_grid,
+                                                   gmx::ArrayRef<const real>      chargeA,
+                                                   gmx::ArrayRef<const real>      chargeB,
+                                                   gmx::ArrayRef<const int>       typeA,
+                                                   gmx::ArrayRef<const int>       typeB,
+                                                   t_lambda*                      fepvals,
+                                                   gmx::ArrayRef<const real>      lambda,
+                                                   gmx_enerdata_t*                enerd,
+                                                   const gmx::StepWorkload&       stepWork,
+                                                   t_nrnb*                        nrnb)
+{
+    if (!pairlistSets_->params().haveFep)
+    {
+        return;
+    }
+
+    GMX_RELEASE_ASSERT(freeEnergyDispatch_, "Need a valid dispatch object");
+
+    freeEnergyDispatch_->dispatchFreeEnergyKernels(*pairlistSets_,
+                                                   coords,
+                                                   forceWithShiftForces,
+                                                   useSimd,
+                                                   ntype,
+                                                   rlist,
+                                                   ic,
+                                                   shiftvec,
+                                                   nbfp,
+                                                   nbfp_grid,
+                                                   chargeA,
+                                                   chargeB,
+                                                   typeA,
+                                                   typeB,
+                                                   fepvals,
+                                                   lambda,
+                                                   enerd,
+                                                   stepWork,
+                                                   nrnb,
+                                                   wcycle_);
+}
diff --git a/src/gromacs/nbnxm/freeenergydispatch.h b/src/gromacs/nbnxm/freeenergydispatch.h
new file mode 100644 (file)
index 0000000..dc09a72
--- /dev/null
@@ -0,0 +1,112 @@
+/*
+ * This file is part of the GROMACS molecular simulation package.
+ *
+ * Copyright (c) 2021, by the GROMACS development team, led by
+ * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
+ * and including many others, as listed in the AUTHORS file in the
+ * top-level source directory and at http://www.gromacs.org.
+ *
+ * GROMACS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1
+ * of the License, or (at your option) any later version.
+ *
+ * GROMACS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with GROMACS; if not, see
+ * http://www.gnu.org/licenses, or write to the Free Software Foundation,
+ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
+ *
+ * If you want to redistribute modifications to GROMACS, please
+ * consider that scientific software is very special. Version
+ * control is crucial - bugs must be traceable. We will be happy to
+ * consider code for inclusion in the official distribution, but
+ * derived work must not be called official GROMACS. Details are found
+ * in the README & COPYING files - if they are missing, get the
+ * official version at http://www.gromacs.org.
+ *
+ * To help us fund GROMACS development, we humbly ask that you cite
+ * the research papers on the package. Check out http://www.gromacs.org.
+ */
+/*! \internal \file
+ *
+ * \brief
+ * Declares the free-energy kernel dispatch class
+ *
+ * \author Berk Hess <hess@kth.se>
+ * \ingroup module_nbnxm
+ */
+#ifndef GMX_NBNXM_FREEENERGYDISPATCH_H
+#define GMX_NBNXM_FREEENERGYDISPATCH_H
+
+#include <memory>
+
+#include "gromacs/math/vectypes.h"
+#include "gromacs/mdtypes/enerdata.h"
+#include "gromacs/mdtypes/threaded_force_buffer.h"
+#include "gromacs/utility/arrayref.h"
+
+struct gmx_enerdata_t;
+struct gmx_wallcycle;
+struct interaction_const_t;
+class PairlistSets;
+struct t_lambda;
+struct t_nrnb;
+
+namespace gmx
+{
+template<typename>
+class ArrayRefWithPadding;
+class ForceWithShiftForces;
+class StepWorkload;
+} // namespace gmx
+
+/*! \libinternal
+ *  \brief Temporary data and methods for handling dispatching of the nbnxm free-energy kernels
+ */
+class FreeEnergyDispatch
+{
+public:
+    //! Constructor
+    FreeEnergyDispatch(int numEnergyGroups);
+
+    //! Sets up the threaded force buffer and the reduction, should be called after constructing the pair lists
+    void setupFepThreadedForceBuffer(int numAtomsForce, const PairlistSets& pairlistSets);
+
+    //! Dispatches the non-bonded free-energy kernels, thread parallel and reduces the output
+    void dispatchFreeEnergyKernels(const PairlistSets&                       pairlistSets,
+                                   const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                                   gmx::ForceWithShiftForces*                forceWithShiftForces,
+                                   bool                                      useSimd,
+                                   int                                       ntype,
+                                   real                                      rlist,
+                                   const interaction_const_t&                ic,
+                                   gmx::ArrayRef<const gmx::RVec>            shiftvec,
+                                   gmx::ArrayRef<const real>                 nbfp,
+                                   gmx::ArrayRef<const real>                 nbfp_grid,
+                                   gmx::ArrayRef<const real>                 chargeA,
+                                   gmx::ArrayRef<const real>                 chargeB,
+                                   gmx::ArrayRef<const int>                  typeA,
+                                   gmx::ArrayRef<const int>                  typeB,
+                                   t_lambda*                                 fepvals,
+                                   gmx::ArrayRef<const real>                 lambda,
+                                   gmx_enerdata_t*                           enerd,
+                                   const gmx::StepWorkload&                  stepWork,
+                                   t_nrnb*                                   nrnb,
+                                   gmx_wallcycle*                            wcycle);
+
+private:
+    //! Temporary array for storing foreign lambda group pair energies
+    gmx_grppairener_t foreignGroupPairEnergies_;
+
+    //! Threaded force buffer for nonbonded FEP
+    gmx::ThreadedForceBuffer<gmx::RVec> threadedForceBuffer_;
+    //! Threaded buffer for nonbonded FEP foreign energies and dVdl, no forces, so numAtoms = 0
+    gmx::ThreadedForceBuffer<gmx::RVec> threadedForeignEnergyBuffer_;
+};
+
+#endif // GMX_NBNXN_FREEENERGYDISPATCH_H
index f7229169e3f30ad76794730f275cd1198f54e2f5..be105c93ee9eecfd2e3b1e4afd74c5724637d8ca 100644 (file)
@@ -37,8 +37,6 @@
 #include "gmxpre.h"
 
 #include "gromacs/gmxlib/nrnb.h"
-#include "gromacs/gmxlib/nonbonded/nb_free_energy.h"
-#include "gromacs/gmxlib/nonbonded/nonbonded.h"
 #include "gromacs/math/vectypes.h"
 #include "gromacs/mdlib/enerdata_utils.h"
 #include "gromacs/mdlib/force.h"
@@ -497,167 +495,3 @@ void nonbonded_verlet_t::dispatchNonbondedKernel(gmx::InteractionLocality
 
     accountFlops(nrnb, pairlistSet, *this, ic, stepWork);
 }
-
-void nonbonded_verlet_t::dispatchFreeEnergyKernel(gmx::InteractionLocality       iLocality,
-                                                  gmx::ArrayRef<const gmx::RVec> coords,
-                                                  gmx::ForceWithShiftForces* forceWithShiftForces,
-                                                  bool                       useSimd,
-                                                  int                        ntype,
-                                                  real                       rlist,
-                                                  const interaction_const_t& ic,
-                                                  gmx::ArrayRef<const gmx::RVec> shiftvec,
-                                                  gmx::ArrayRef<const real>      nbfp,
-                                                  gmx::ArrayRef<const real>      nbfp_grid,
-                                                  gmx::ArrayRef<const real>      chargeA,
-                                                  gmx::ArrayRef<const real>      chargeB,
-                                                  gmx::ArrayRef<const int>       typeA,
-                                                  gmx::ArrayRef<const int>       typeB,
-                                                  t_lambda*                      fepvals,
-                                                  gmx::ArrayRef<const real>      lambda,
-                                                  gmx_enerdata_t*                enerd,
-                                                  const gmx::StepWorkload&       stepWork,
-                                                  t_nrnb*                        nrnb)
-{
-    const auto nbl_fep = pairlistSets().pairlistSet(iLocality).fepLists();
-
-    /* When the first list is empty, all are empty and there is nothing to do */
-    if (!pairlistSets().params().haveFep || nbl_fep[0]->nrj == 0)
-    {
-        return;
-    }
-
-    int donb_flags = 0;
-    /* Add short-range interactions */
-    donb_flags |= GMX_NONBONDED_DO_SR;
-
-    if (stepWork.computeForces)
-    {
-        donb_flags |= GMX_NONBONDED_DO_FORCE;
-    }
-    if (stepWork.computeVirial)
-    {
-        donb_flags |= GMX_NONBONDED_DO_SHIFTFORCE;
-    }
-    if (stepWork.computeEnergy)
-    {
-        donb_flags |= GMX_NONBONDED_DO_POTENTIAL;
-    }
-
-    gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real> dvdl_nb      = { 0 };
-    int                                                             kernelFlags  = donb_flags;
-    gmx::ArrayRef<const real>                                       kernelLambda = lambda;
-    gmx::ArrayRef<real>                                             kernelDvdl   = dvdl_nb;
-
-    gmx::ArrayRef<real> energygrp_elec = enerd->grpp.energyGroupPairTerms[NonBondedEnergyTerms::CoulombSR];
-    gmx::ArrayRef<real> energygrp_vdw = enerd->grpp.energyGroupPairTerms[NonBondedEnergyTerms::LJSR];
-
-    GMX_ASSERT(gmx_omp_nthreads_get(ModuleMultiThread::Nonbonded) == nbl_fep.ssize(),
-               "Number of lists should be same as number of NB threads");
-
-    wallcycle_sub_start(wcycle_, WallCycleSubCounter::NonbondedFep);
-#pragma omp parallel for schedule(static) num_threads(nbl_fep.ssize())
-    for (gmx::index th = 0; th < nbl_fep.ssize(); th++)
-    {
-        try
-        {
-            gmx_nb_free_energy_kernel(*nbl_fep[th],
-                                      coords,
-                                      forceWithShiftForces,
-                                      useSimd,
-                                      ntype,
-                                      rlist,
-                                      ic,
-                                      shiftvec,
-                                      nbfp,
-                                      nbfp_grid,
-                                      chargeA,
-                                      chargeB,
-                                      typeA,
-                                      typeB,
-                                      kernelFlags,
-                                      kernelLambda,
-                                      kernelDvdl,
-                                      energygrp_elec,
-                                      energygrp_vdw,
-                                      nrnb);
-        }
-        GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
-    }
-
-    if (fepvals->sc_alpha != 0)
-    {
-        enerd->dvdl_nonlin[FreeEnergyPerturbationCouplingType::Vdw] +=
-                dvdl_nb[FreeEnergyPerturbationCouplingType::Vdw];
-        enerd->dvdl_nonlin[FreeEnergyPerturbationCouplingType::Coul] +=
-                dvdl_nb[FreeEnergyPerturbationCouplingType::Coul];
-    }
-    else
-    {
-        enerd->dvdl_lin[FreeEnergyPerturbationCouplingType::Vdw] +=
-                dvdl_nb[FreeEnergyPerturbationCouplingType::Vdw];
-        enerd->dvdl_lin[FreeEnergyPerturbationCouplingType::Coul] +=
-                dvdl_nb[FreeEnergyPerturbationCouplingType::Coul];
-    }
-
-    /* If we do foreign lambda and we have soft-core interactions
-     * we have to recalculate the (non-linear) energies contributions.
-     */
-    if (fepvals->n_lambda > 0 && stepWork.computeDhdl && fepvals->sc_alpha != 0)
-    {
-        gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real> lam_i;
-        kernelFlags = (donb_flags & ~(GMX_NONBONDED_DO_FORCE | GMX_NONBONDED_DO_SHIFTFORCE))
-                      | GMX_NONBONDED_DO_FOREIGNLAMBDA;
-        kernelLambda = lam_i;
-        kernelDvdl   = dvdl_nb;
-        gmx::ArrayRef<real> energygrp_elec =
-                foreignEnergyGroups_->energyGroupPairTerms[NonBondedEnergyTerms::CoulombSR];
-        gmx::ArrayRef<real> energygrp_vdw =
-                foreignEnergyGroups_->energyGroupPairTerms[NonBondedEnergyTerms::LJSR];
-
-        for (gmx::index i = 0; i < 1 + enerd->foreignLambdaTerms.numLambdas(); i++)
-        {
-            std::fill(std::begin(dvdl_nb), std::end(dvdl_nb), 0);
-            for (int j = 0; j < static_cast<int>(FreeEnergyPerturbationCouplingType::Count); j++)
-            {
-                lam_i[j] = (i == 0 ? lambda[j] : fepvals->all_lambda[j][i - 1]);
-            }
-            foreignEnergyGroups_->clear();
-#pragma omp parallel for schedule(static) num_threads(nbl_fep.ssize())
-            for (gmx::index th = 0; th < nbl_fep.ssize(); th++)
-            {
-                try
-                {
-                    gmx_nb_free_energy_kernel(*nbl_fep[th],
-                                              coords,
-                                              forceWithShiftForces,
-                                              useSimd,
-                                              ntype,
-                                              rlist,
-                                              ic,
-                                              shiftvec,
-                                              nbfp,
-                                              nbfp_grid,
-                                              chargeA,
-                                              chargeB,
-                                              typeA,
-                                              typeB,
-                                              kernelFlags,
-                                              kernelLambda,
-                                              kernelDvdl,
-                                              energygrp_elec,
-                                              energygrp_vdw,
-                                              nrnb);
-                }
-                GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
-            }
-            std::array<real, F_NRE> foreign_term = { 0 };
-            sum_epot(*foreignEnergyGroups_, foreign_term.data());
-            enerd->foreignLambdaTerms.accumulate(
-                    i,
-                    foreign_term[F_EPOT],
-                    dvdl_nb[FreeEnergyPerturbationCouplingType::Vdw]
-                            + dvdl_nb[FreeEnergyPerturbationCouplingType::Coul]);
-        }
-    }
-    wallcycle_sub_stop(wcycle_, WallCycleSubCounter::NonbondedFep);
-}
index dca77313ed9168c3dd48a1ca37dc4d8e1a4018ff..f634f75fac7b0ed8eb8c31715c50af1e060b2c79 100644 (file)
@@ -258,5 +258,4 @@ bool buildSupportsNonbondedOnGpu(std::string* error)
     return errorReasons.isEmpty();
 }
 
-
 /*! \endcond */
index e3e90d9238115710d8048f5c61cba340c8cc0a58..b438c9bc332babd3581949d1b6909c35c042a18f 100644 (file)
 #include "gromacs/utility/real.h"
 
 struct DeviceInformation;
+class FreeEnergyDispatch;
 struct gmx_domdec_zones_t;
 struct gmx_enerdata_t;
 struct gmx_hw_info_t;
@@ -144,6 +145,8 @@ class GpuEventSynchronizer;
 
 namespace gmx
 {
+template<typename>
+class ArrayRefWithPadding;
 class DeviceStreamManager;
 class ForceWithShiftForces;
 class ListedForcesGpu;
@@ -368,26 +371,25 @@ public:
                                  gmx::ArrayRef<real>            CoulombSR,
                                  t_nrnb*                        nrnb) const;
 
-    //! Executes the non-bonded free-energy kernel, always runs on the CPU
-    void dispatchFreeEnergyKernel(gmx::InteractionLocality       iLocality,
-                                  gmx::ArrayRef<const gmx::RVec> coords,
-                                  gmx::ForceWithShiftForces*     forceWithShiftForces,
-                                  bool                           useSimd,
-                                  int                            ntype,
-                                  real                           rlist,
-                                  const interaction_const_t&     ic,
-                                  gmx::ArrayRef<const gmx::RVec> shiftvec,
-                                  gmx::ArrayRef<const real>      nbfp,
-                                  gmx::ArrayRef<const real>      nbfp_grid,
-                                  gmx::ArrayRef<const real>      chargeA,
-                                  gmx::ArrayRef<const real>      chargeB,
-                                  gmx::ArrayRef<const int>       typeA,
-                                  gmx::ArrayRef<const int>       typeB,
-                                  t_lambda*                      fepvals,
-                                  gmx::ArrayRef<const real>      lambda,
-                                  gmx_enerdata_t*                enerd,
-                                  const gmx::StepWorkload&       stepWork,
-                                  t_nrnb*                        nrnb);
+    //! Executes the non-bonded free-energy kernels, local + non-local, always runs on the CPU
+    void dispatchFreeEnergyKernels(const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
+                                   gmx::ForceWithShiftForces*                forceWithShiftForces,
+                                   bool                                      useSimd,
+                                   int                                       ntype,
+                                   real                                      rlist,
+                                   const interaction_const_t&                ic,
+                                   gmx::ArrayRef<const gmx::RVec>            shiftvec,
+                                   gmx::ArrayRef<const real>                 nbfp,
+                                   gmx::ArrayRef<const real>                 nbfp_grid,
+                                   gmx::ArrayRef<const real>                 chargeA,
+                                   gmx::ArrayRef<const real>                 chargeB,
+                                   gmx::ArrayRef<const int>                  typeA,
+                                   gmx::ArrayRef<const int>                  typeB,
+                                   t_lambda*                                 fepvals,
+                                   gmx::ArrayRef<const real>                 lambda,
+                                   gmx_enerdata_t*                           enerd,
+                                   const gmx::StepWorkload&                  stepWork,
+                                   t_nrnb*                                   nrnb);
 
     /*! \brief Add the forces stored in nbat to f, zeros the forces in nbat
      * \param [in] locality         Local or non-local
@@ -418,6 +420,8 @@ public:
     void setupGpuShortRangeWork(const gmx::ListedForcesGpu* listedForcesGpu,
                                 gmx::InteractionLocality    iLocality) const;
 
+    void setupFepThreadedForceBuffer(int numAtomsForce);
+
     // TODO: Make all data members private
     //! All data related to the pair lists
     std::unique_ptr<PairlistSets> pairlistSets_;
@@ -429,10 +433,12 @@ public:
 private:
     //! The non-bonded setup, also affects the pairlist construction kernel
     Nbnxm::KernelSetup kernelSetup_;
+
     //! \brief Pointer to wallcycle structure.
     gmx_wallcycle* wcycle_;
-    //! Temporary array for storing foreign lambda group pair energies
-    std::unique_ptr<gmx_grppairener_t> foreignEnergyGroups_;
+
+    //! \brief The non-bonded free-energy kernel dispatcher
+    std::unique_ptr<FreeEnergyDispatch> freeEnergyDispatch_;
 
 public:
     //! GPU Nbnxm data, only used with a physical GPU (TODO: use unique_ptr)
index b279596730bb9f142cefbc613ceff9937736865c..3b448938fb52abbfff2508c5c95884b4c85a5635 100644 (file)
@@ -62,6 +62,7 @@
 #include "gromacs/utility/fatalerror.h"
 #include "gromacs/utility/logger.h"
 
+#include "freeenergydispatch.h"
 #include "grid.h"
 #include "nbnxm_geometry.h"
 #include "nbnxm_simd.h"
@@ -485,12 +486,16 @@ nonbonded_verlet_t::nonbonded_verlet_t(std::unique_ptr<PairlistSets>     pairlis
     nbat(std::move(nbat_in)),
     kernelSetup_(kernelSetup),
     wcycle_(wcycle),
-    foreignEnergyGroups_(std::make_unique<gmx_grppairener_t>(nbat->params().nenergrp)),
     gpu_nbv(gpu_nbv_ptr)
 {
     GMX_RELEASE_ASSERT(pairlistSets_, "Need valid pairlistSets");
     GMX_RELEASE_ASSERT(pairSearch_, "Need valid search object");
     GMX_RELEASE_ASSERT(nbat, "Need valid atomdata object");
+
+    if (pairlistSets_->params().haveFep)
+    {
+        freeEnergyDispatch_ = std::make_unique<FreeEnergyDispatch>(nbat->params().nenergrp);
+    }
 }
 
 nonbonded_verlet_t::~nonbonded_verlet_t()
index cd6f4324854dda3abc26c35346ff0b1c419617ab..12d5f376d0f33727800eaf78d7e68844e2f1e0da 100644 (file)
@@ -149,6 +149,7 @@ static const char* enumValuetoString(WallCycleSubCounter enumValue)
         "Nonbonded F kernel",
         "Nonbonded F clear",
         "Nonbonded FEP",
+        "Nonbonded FEP reduction",
         "Launch NB GPU tasks",
         "Launch Bonded GPU tasks",
         "Launch PME GPU tasks",
index 60669d3fbc18c38ba147d2ccdfb94e8a57f25f67..ca256607c5c26d55d8568917c833f20dae2d9d04 100644 (file)
@@ -142,6 +142,7 @@ enum class WallCycleSubCounter : int
     NonbondedKernel,
     NonbondedClear,
     NonbondedFep,
+    NonbondedFepReduction,
     LaunchGpuNonBonded,
     LaunchGpuBonded,
     LaunchGpuPme,