Merge branch 'origin/release-2020' into master
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_free_energy.cpp
index 19e5c660dd26bf9e9eaa0a07b502f15d26b376d0..e23eef5bacac81b46820d97095190cd3f8cfbf02 100644 (file)
@@ -3,7 +3,8 @@
  *
  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
  * Copyright (c) 2001-2004, The GROMACS development team.
- * Copyright (c) 2013,2014,2015,2016,2017,2018,2019,2020, by the GROMACS development team, led by
+ * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
+ * Copyright (c) 2018,2019,2020, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
@@ -38,6 +39,8 @@
 
 #include "nb_free_energy.h"
 
+#include "config.h"
+
 #include <cmath>
 
 #include <algorithm>
 #include "gromacs/math/vec.h"
 #include "gromacs/mdtypes/forceoutput.h"
 #include "gromacs/mdtypes/forcerec.h"
+#include "gromacs/mdtypes/interaction_const.h"
 #include "gromacs/mdtypes/md_enums.h"
+#include "gromacs/mdtypes/mdatom.h"
+#include "gromacs/simd/simd.h"
 #include "gromacs/utility/fatalerror.h"
 
 
-//! Enum for templating the soft-core treatment in the kernel
-enum class SoftCoreTreatment
-{
-    None,    //!< No soft-core
-    RPower6, //!< Soft-core with r-power = 6
-    RPower48 //!< Soft-core with r-power = 48
-};
-
-//! Most treatments are fine with float in mixed-precision mode.
-template<SoftCoreTreatment softCoreTreatment>
-struct SoftCoreReal
+//! Scalar (non-SIMD) data types.
+struct ScalarDataTypes
 {
-    //! Real type for soft-core calculations
-    using Real = real;
+    using RealType                     = real; //!< The data type to use as real.
+    using IntType                      = int;  //!< The data type to use as int.
+    static constexpr int simdRealWidth = 1;    //!< The width of the RealType.
+    static constexpr int simdIntWidth  = 1;    //!< The width of the IntType.
 };
 
-//! This treatment requires double precision for some computations.
-template<>
-struct SoftCoreReal<SoftCoreTreatment::RPower48>
+#if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS
+//! SIMD data types.
+struct SimdDataTypes
 {
-    //! Real type for soft-core calculations
-    using Real = double;
+    using RealType                     = gmx::SimdReal;         //!< The data type to use as real.
+    using IntType                      = gmx::SimdInt32;        //!< The data type to use as int.
+    static constexpr int simdRealWidth = GMX_SIMD_REAL_WIDTH;   //!< The width of the RealType.
+    static constexpr int simdIntWidth  = GMX_SIMD_FINT32_WIDTH; //!< The width of the IntType.
 };
+#endif
 
 //! Computes r^(1/p) and 1/r^(1/p) for the standard p=6
-template<SoftCoreTreatment softCoreTreatment>
-static inline void pthRoot(const real r, real* pthRoot, real* invPthRoot)
-{
-    *invPthRoot = gmx::invsqrt(std::cbrt(r));
-    *pthRoot    = 1 / (*invPthRoot);
-}
-
-// We need a double version to make the specialization below work
-#if !GMX_DOUBLE
-//! Computes r^(1/p) and 1/r^(1/p) for the standard p=6
-template<SoftCoreTreatment softCoreTreatment>
-static inline void pthRoot(const double r, real* pthRoot, double* invPthRoot)
+template<class RealType>
+static inline void pthRoot(const RealType r, RealType* pthRoot, RealType* invPthRoot)
 {
     *invPthRoot = gmx::invsqrt(std::cbrt(r));
     *pthRoot    = 1 / (*invPthRoot);
 }
-#endif
 
-//! Computes r^(1/p) and 1/r^(1/p) for p=48
-template<>
-inline void pthRoot<SoftCoreTreatment::RPower48>(const double r, real* pthRoot, double* invPthRoot)
+template<class RealType>
+static inline RealType calculateRinv6(const RealType rinvV)
 {
-    *pthRoot    = std::pow(r, 1.0 / 48.0);
-    *invPthRoot = 1 / (*pthRoot);
+    RealType rinv6 = rinvV * rinvV;
+    return (rinv6 * rinv6 * rinv6);
 }
 
-template<SoftCoreTreatment softCoreTreatment>
-static inline real calculateSigmaPow(const real sigma6)
-{
-    if (softCoreTreatment == SoftCoreTreatment::RPower6)
-    {
-        return sigma6;
-    }
-    else
-    {
-        real sigmaPow = sigma6 * sigma6;     /* sigma^12 */
-        sigmaPow      = sigmaPow * sigmaPow; /* sigma^24 */
-        sigmaPow      = sigmaPow * sigmaPow; /* sigma^48 */
-        return (sigmaPow);
-    }
-}
-
-template<SoftCoreTreatment softCoreTreatment, class SCReal>
-static inline real calculateRinv6(const SCReal rinvV)
-{
-    if (softCoreTreatment == SoftCoreTreatment::RPower6)
-    {
-        return rinvV;
-    }
-    else
-    {
-        real rinv6 = rinvV * rinvV;
-        return (rinv6 * rinv6 * rinv6);
-    }
-}
-
-static inline real calculateVdw6(const real c6, const real rinv6)
+template<class RealType>
+static inline RealType calculateVdw6(const RealType c6, const RealType rinv6)
 {
     return (c6 * rinv6);
 }
 
-static inline real calculateVdw12(const real c12, const real rinv6)
+template<class RealType>
+static inline RealType calculateVdw12(const RealType c12, const RealType rinv6)
 {
     return (c12 * rinv6 * rinv6);
 }
 
 /* reaction-field electrostatics */
-template<class SCReal>
-static inline SCReal
-reactionFieldScalarForce(const real qq, const real rinv, const SCReal r, const real krf, const real two)
+template<class RealType>
+static inline RealType reactionFieldScalarForce(const RealType qq,
+                                                const RealType rinv,
+                                                const RealType r,
+                                                const real     krf,
+                                                const real     two)
 {
     return (qq * (rinv - two * krf * r * r));
 }
-template<class SCReal>
-static inline real
-reactionFieldPotential(const real qq, const real rinv, const SCReal r, const real krf, const real potentialShift)
+template<class RealType>
+static inline RealType reactionFieldPotential(const RealType qq,
+                                              const RealType rinv,
+                                              const RealType r,
+                                              const real     krf,
+                                              const real     potentialShift)
 {
     return (qq * (rinv + krf * r * r - potentialShift));
 }
 
 /* Ewald electrostatics */
-static inline real ewaldScalarForce(const real coulomb, const real rinv)
+template<class RealType>
+static inline RealType ewaldScalarForce(const RealType coulomb, const RealType rinv)
 {
     return (coulomb * rinv);
 }
-static inline real ewaldPotential(const real coulomb, const real rinv, const real potentialShift)
+template<class RealType>
+static inline RealType ewaldPotential(const RealType coulomb, const RealType rinv, const real potentialShift)
 {
     return (coulomb * (rinv - potentialShift));
 }
 
 /* cutoff LJ */
-static inline real lennardJonesScalarForce(const real v6, const real v12)
+template<class RealType>
+static inline RealType lennardJonesScalarForce(const RealType v6, const RealType v12)
 {
     return (v12 - v6);
 }
-static inline real lennardJonesPotential(const real v6,
-                                         const real v12,
-                                         const real c6,
-                                         const real c12,
-                                         const real repulsionShift,
-                                         const real dispersionShift,
-                                         const real onesixth,
-                                         const real onetwelfth)
+template<class RealType>
+static inline RealType lennardJonesPotential(const RealType v6,
+                                             const RealType v12,
+                                             const RealType c6,
+                                             const RealType c12,
+                                             const real     repulsionShift,
+                                             const real     dispersionShift,
+                                             const real     onesixth,
+                                             const real     onetwelfth)
 {
     return ((v12 + c12 * repulsionShift) * onetwelfth - (v6 + c6 * dispersionShift) * onesixth);
 }
@@ -192,25 +164,28 @@ static inline real ewaldLennardJonesGridSubtract(const real c6grid, const real p
 }
 
 /* LJ Potential switch */
-template<class SCReal>
-static inline SCReal potSwitchScalarForceMod(const SCReal fScalarInp,
-                                             const real   potential,
-                                             const real   sw,
-                                             const SCReal r,
-                                             const real   rVdw,
-                                             const real   dsw,
-                                             const real   zero)
+template<class RealType>
+static inline RealType potSwitchScalarForceMod(const RealType fScalarInp,
+                                               const RealType potential,
+                                               const RealType sw,
+                                               const RealType r,
+                                               const RealType rVdw,
+                                               const RealType dsw,
+                                               const real     zero)
 {
     if (r < rVdw)
     {
-        SCReal fScalar = fScalarInp * sw - r * potential * dsw;
+        RealType fScalar = fScalarInp * sw - r * potential * dsw;
         return (fScalar);
     }
     return (zero);
 }
-template<class SCReal>
-static inline real
-potSwitchPotentialMod(const real potentialInp, const real sw, const SCReal r, const real rVdw, const real zero)
+template<class RealType>
+static inline RealType potSwitchPotentialMod(const RealType potentialInp,
+                                             const RealType sw,
+                                             const RealType r,
+                                             const RealType rVdw,
+                                             const real     zero)
 {
     if (r < rVdw)
     {
@@ -222,7 +197,7 @@ potSwitchPotentialMod(const real potentialInp, const real sw, const SCReal r, co
 
 
 //! Templated free-energy non-bonded kernel
-template<SoftCoreTreatment softCoreTreatment, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
+template<typename DataTypes, bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
 static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                                   rvec* gmx_restrict         xx,
                                   gmx::ForceWithShiftForces* forceWithShiftForces,
@@ -231,14 +206,14 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                                   nb_kernel_data_t* gmx_restrict kernel_data,
                                   t_nrnb* gmx_restrict nrnb)
 {
-    using SCReal = typename SoftCoreReal<softCoreTreatment>::Real;
-
-    constexpr bool useSoftCore = (softCoreTreatment != SoftCoreTreatment::None);
-
 #define STATE_A 0
 #define STATE_B 1
 #define NSTATES 2
 
+    using RealType = typename DataTypes::RealType;
+    using IntType  = typename DataTypes::IntType;
+
+    /* FIXME: How should these be handled with SIMD? */
     constexpr real onetwelfth = 1.0 / 12.0;
     constexpr real onesixth   = 1.0 / 6.0;
     constexpr real zero       = 0.0;
@@ -265,7 +240,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
     const int*  typeA         = mdatoms->typeA;
     const int*  typeB         = mdatoms->typeB;
     const int   ntype         = fr->ntype;
-    const real* nbfp          = fr->nbfp;
+    const real* nbfp          = fr->nbfp.data();
     const real* nbfp_grid     = fr->ljpme_c6grid;
     real*       Vv            = kernel_data->energygrp_vdw;
     const real  lambda_coul   = kernel_data->lambda[efptCOUL];
@@ -367,8 +342,8 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
     GMX_RELEASE_ASSERT(!(vdwInteractionTypeIsEwald && vdwModifierIsPotSwitch),
                        "Can not apply soft-core to switched Ewald potentials");
 
-    SCReal dvdl_coul = 0; /* Needs double for sc_power==48 */
-    SCReal dvdl_vdw  = 0; /* Needs double for sc_power==48 */
+    real dvdl_coul = 0;
+    real dvdl_vdw  = 0;
 
     /* Lambda factor for state A, 1-lambda*/
     real LFC[NSTATES], LFV[NSTATES];
@@ -385,7 +360,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
     DLF[STATE_B] = 1;
 
     real           lfac_coul[NSTATES], dlfac_coul[NSTATES], lfac_vdw[NSTATES], dlfac_vdw[NSTATES];
-    constexpr real sc_r_power = (softCoreTreatment == SoftCoreTreatment::RPower48 ? 48.0_real : 6.0_real);
+    constexpr real sc_r_power = 6.0_real;
     for (int i = 0; i < NSTATES; i++)
     {
         lfac_coul[i]  = (lam_power == 2 ? (1 - LFC[i]) * (1 - LFC[i]) : (1 - LFC[i]));
@@ -426,25 +401,31 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
 
         for (int k = nj0; k < nj1; k++)
         {
-            int        tj[NSTATES];
-            const int  jnr = jjnr[k];
-            const int  j3  = 3 * jnr;
-            real       c6[NSTATES], c12[NSTATES], qq[NSTATES], Vcoul[NSTATES], Vvdw[NSTATES];
-            real       r, rinv, rp, rpm2;
-            real       alpha_vdw_eff, alpha_coul_eff, sigma_pow[NSTATES];
-            const real dx  = ix - x[j3];
-            const real dy  = iy - x[j3 + 1];
-            const real dz  = iz - x[j3 + 2];
-            const real rsq = dx * dx + dy * dy + dz * dz;
-            SCReal     FscalC[NSTATES], FscalV[NSTATES]; /* Needs double for sc_power==48 */
-
-            if (rsq >= rcutoff_max2)
+            int            tj[NSTATES];
+            const int      jnr = jjnr[k];
+            const int      j3  = 3 * jnr;
+            RealType       c6[NSTATES], c12[NSTATES], qq[NSTATES], Vcoul[NSTATES], Vvdw[NSTATES];
+            RealType       r, rinv, rp, rpm2;
+            RealType       alpha_vdw_eff, alpha_coul_eff, sigma6[NSTATES];
+            const RealType dx  = ix - x[j3];
+            const RealType dy  = iy - x[j3 + 1];
+            const RealType dz  = iz - x[j3 + 2];
+            const RealType rsq = dx * dx + dy * dy + dz * dz;
+            RealType       FscalC[NSTATES], FscalV[NSTATES];
+            /* Check whether this pair is included, i.e. not on the exclusions list. */
+            const bool bPairIncluded = nlist->excl_fep == nullptr || nlist->excl_fep[k];
+
+            if (rsq >= rcutoff_max2 && bPairIncluded)
             {
                 /* We save significant time by skipping all code below.
                  * Note that with soft-core interactions, the actual cut-off
                  * check might be different. But since the soft-core distance
                  * is always larger than r, checking on r here is safe.
+                 * Exclusions outside the cutoff can not be skipped:
+                 * when using Ewald, the reciprocal-space
+                 * Ewald component still needs to be subtracted.
                  */
+
                 continue;
             }
             npair_within_cutoff++;
@@ -470,7 +451,12 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                 r    = 0;
             }
 
-            if (softCoreTreatment == SoftCoreTreatment::None)
+            if (useSoftCore)
+            {
+                rpm2 = rsq * rsq;  /* r4 */
+                rp   = rpm2 * rsq; /* r6 */
+            }
+            else
             {
                 /* The soft-core power p will not affect the results
                  * with not using soft-core, so we use power of 0 which gives
@@ -479,21 +465,8 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                 rpm2 = rinv * rinv;
                 rp   = 1;
             }
-            if (softCoreTreatment == SoftCoreTreatment::RPower6)
-            {
-                rpm2 = rsq * rsq;  /* r4 */
-                rp   = rpm2 * rsq; /* r6 */
-            }
-            if (softCoreTreatment == SoftCoreTreatment::RPower48)
-            {
-                rp   = rsq * rsq * rsq; /* r6 */
-                rp   = rp * rp;         /* r12 */
-                rp   = rp * rp;         /* r24 */
-                rp   = rp * rp;         /* r48 */
-                rpm2 = rp / rsq;        /* r46 */
-            }
 
-            real Fscal = 0;
+            RealType Fscal = 0;
 
             qq[STATE_A] = iqA * chargeA[jnr];
             qq[STATE_B] = iqB * chargeB[jnr];
@@ -501,7 +474,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
             tj[STATE_A] = ntiA + 2 * typeA[jnr];
             tj[STATE_B] = ntiB + 2 * typeB[jnr];
 
-            if (nlist->excl_fep == nullptr || nlist->excl_fep[k])
+            if (bPairIncluded)
             {
                 c6[STATE_A] = nbfp[tj[STATE_A]];
                 c6[STATE_B] = nbfp[tj[STATE_B]];
@@ -511,7 +484,6 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                     c12[i] = nbfp[tj[i] + 1];
                     if (useSoftCore)
                     {
-                        real sigma6[NSTATES];
                         if ((c6[i] > 0) && (c12[i] > 0))
                         {
                             /* c12 is stored scaled with 12.0 and c6 is scaled with 6.0 - correct for this */
@@ -525,7 +497,6 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                         {
                             sigma6[i] = sigma6_def;
                         }
-                        sigma_pow[i] = calculateSigmaPow<softCoreTreatment>(sigma6[i]);
                     }
                 }
 
@@ -551,21 +522,20 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                     Vcoul[i]  = 0;
                     Vvdw[i]   = 0;
 
-                    real   rinvC, rinvV;
-                    SCReal rC, rV, rpinvC, rpinvV; /* Needs double for sc_power==48 */
+                    RealType rinvC, rinvV, rC, rV, rpinvC, rpinvV;
 
                     /* Only spend time on A or B state if it is non-zero */
                     if ((qq[i] != 0) || (c6[i] != 0) || (c12[i] != 0))
                     {
-                        /* this section has to be inside the loop because of the dependence on sigma_pow */
+                        /* this section has to be inside the loop because of the dependence on sigma6 */
                         if (useSoftCore)
                         {
-                            rpinvC = one / (alpha_coul_eff * lfac_coul[i] * sigma_pow[i] + rp);
-                            pthRoot<softCoreTreatment>(rpinvC, &rinvC, &rC);
+                            rpinvC = one / (alpha_coul_eff * lfac_coul[i] * sigma6[i] + rp);
+                            pthRoot(rpinvC, &rinvC, &rC);
                             if (scLambdasOrAlphasDiffer)
                             {
-                                rpinvV = one / (alpha_vdw_eff * lfac_vdw[i] * sigma_pow[i] + rp);
-                                pthRoot<softCoreTreatment>(rpinvV, &rinvV, &rV);
+                                rpinvV = one / (alpha_vdw_eff * lfac_vdw[i] * sigma6[i] + rp);
+                                pthRoot(rpinvV, &rinvV, &rV);
                             }
                             else
                             {
@@ -616,17 +586,17 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                                                      || (!vdwInteractionTypeIsEwald && rV < rvdw);
                         if ((c6[i] != 0 || c12[i] != 0) && computeVdwInteraction)
                         {
-                            real rinv6;
-                            if (softCoreTreatment == SoftCoreTreatment::RPower6)
+                            RealType rinv6;
+                            if (useSoftCore)
                             {
-                                rinv6 = calculateRinv6<softCoreTreatment>(rpinvV);
+                                rinv6 = rpinvV;
                             }
                             else
                             {
-                                rinv6 = calculateRinv6<softCoreTreatment>(rinvV);
+                                rinv6 = calculateRinv6(rinvV);
                             }
-                            real Vvdw6  = calculateVdw6(c6[i], rinv6);
-                            real Vvdw12 = calculateVdw12(c12[i], rinv6);
+                            RealType Vvdw6  = calculateVdw6(c6[i], rinv6);
+                            RealType Vvdw12 = calculateVdw12(c12[i], rinv6);
 
                             Vvdw[i] = lennardJonesPotential(Vvdw6, Vvdw12, c6[i], c12[i], repulsionShift,
                                                             dispersionShift, onesixth, onetwelfth);
@@ -641,11 +611,12 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
 
                             if (vdwModifierIsPotSwitch)
                             {
-                                real d        = rV - ic->rvdw_switch;
-                                d             = (d > zero) ? d : zero;
-                                const real d2 = d * d;
-                                const real sw = one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5));
-                                const real dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4));
+                                RealType d        = rV - ic->rvdw_switch;
+                                d                 = (d > zero) ? d : zero;
+                                const RealType d2 = d * d;
+                                const RealType sw =
+                                        one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5));
+                                const RealType dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4));
 
                                 FscalV[i] = potSwitchScalarForceMod(FscalV[i], Vvdw[i], sw, rV,
                                                                     rvdw, dsw, zero);
@@ -661,7 +632,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                         FscalC[i] *= rpinvC;
                         FscalV[i] *= rpinvV;
                     }
-                }
+                } // end for (int i = 0; i < NSTATES; i++)
 
                 /* Assemble A and B states */
                 for (int i = 0; i < NSTATES; i++)
@@ -674,11 +645,10 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
 
                     if (useSoftCore)
                     {
-                        dvdl_coul +=
-                                Vcoul[i] * DLF[i]
-                                + LFC[i] * alpha_coul_eff * dlfac_coul[i] * FscalC[i] * sigma_pow[i];
+                        dvdl_coul += Vcoul[i] * DLF[i]
+                                     + LFC[i] * alpha_coul_eff * dlfac_coul[i] * FscalC[i] * sigma6[i];
                         dvdl_vdw += Vvdw[i] * DLF[i]
-                                    + LFV[i] * alpha_vdw_eff * dlfac_vdw[i] * FscalV[i] * sigma_pow[i];
+                                    + LFV[i] * alpha_vdw_eff * dlfac_vdw[i] * FscalV[i] * sigma6[i];
                     }
                     else
                     {
@@ -686,7 +656,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                         dvdl_vdw += Vvdw[i] * DLF[i];
                     }
                 }
-            }
+            } // end if (bPairIncluded)
             else if (icoul == GMX_NBKERNEL_ELEC_REACTIONFIELD)
             {
                 /* For excluded pairs, which are only in this pair list when
@@ -694,7 +664,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                  * As there is no singularity, there is no need for soft-core.
                  */
                 const real FF = -two * krf;
-                real       VV = krf * rsq - crf;
+                RealType   VV = krf * rsq - crf;
 
                 if (ii == jnr)
                 {
@@ -709,7 +679,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                 }
             }
 
-            if (elecInteractionTypeIsEwald && r < rcoulomb)
+            if (elecInteractionTypeIsEwald && (r < rcoulomb || !bPairIncluded))
             {
                 /* See comment in the preamble. When using Ewald interactions
                  * (unless we use a switch modifier) we subtract the reciprocal-space
@@ -721,11 +691,11 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                  */
                 real v_lr, f_lr;
 
-                const real ewrt   = r * coulombTableScale;
-                int        ewitab = static_cast<int>(ewrt);
-                const real eweps  = ewrt - ewitab;
-                ewitab            = 4 * ewitab;
-                f_lr              = ewtab[ewitab] + eweps * ewtab[ewitab + 1];
+                const RealType ewrt   = r * coulombTableScale;
+                IntType        ewitab = static_cast<IntType>(ewrt);
+                const RealType eweps  = ewrt - ewitab;
+                ewitab                = 4 * ewitab;
+                f_lr                  = ewtab[ewitab] + eweps * ewtab[ewitab + 1];
                 v_lr = (ewtab[ewitab + 2] - coulombTableScaleInvHalf * eweps * (ewtab[ewitab] + f_lr));
                 f_lr *= rinv;
 
@@ -766,16 +736,17 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                  * r close to 0 for non-interacting pairs.
                  */
 
-                const real rs   = rsq * rinv * vdwTableScale;
-                const int  ri   = static_cast<int>(rs);
-                const real frac = rs - ri;
-                const real f_lr = (1 - frac) * tab_ewald_F_lj[ri] + frac * tab_ewald_F_lj[ri + 1];
+                const RealType rs   = rsq * rinv * vdwTableScale;
+                const IntType  ri   = static_cast<IntType>(rs);
+                const RealType frac = rs - ri;
+                const RealType f_lr = (1 - frac) * tab_ewald_F_lj[ri] + frac * tab_ewald_F_lj[ri + 1];
                 /* TODO: Currently the Ewald LJ table does not contain
                  * the factor 1/6, we should add this.
                  */
-                const real FF = f_lr * rinv / six;
-                real VV = (tab_ewald_V_lj[ri] - vdwTableScaleInvHalf * frac * (tab_ewald_F_lj[ri] + f_lr))
-                          / six;
+                const RealType FF = f_lr * rinv / six;
+                RealType       VV =
+                        (tab_ewald_V_lj[ri] - vdwTableScaleInvHalf * frac * (tab_ewald_F_lj[ri] + f_lr))
+                        / six;
 
                 if (ii == jnr)
                 {
@@ -818,7 +789,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
 #pragma omp atomic
                 f[j3 + 2] -= tz;
             }
-        }
+        } // end for (int k = nj0; k < nj1; k++)
 
         /* The atomics below are expensive with many OpenMP threads.
          * Here unperturbed i-particles will usually only have a few
@@ -854,7 +825,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                 Vv[ggid] += vvtot;
             }
         }
-    }
+    } // end for (int n = 0; n < nri; n++)
 
 #pragma omp atomic
     dvdl[efptCOUL] += dvdl_coul;
@@ -877,69 +848,93 @@ typedef void (*KernelFunction)(const t_nblist* gmx_restrict nlist,
                                nb_kernel_data_t* gmx_restrict kernel_data,
                                t_nrnb* gmx_restrict nrnb);
 
-template<SoftCoreTreatment softCoreTreatment, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald>
-static KernelFunction dispatchKernelOnVdwModifier(const bool vdwModifierIsPotSwitch)
+template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
+static KernelFunction dispatchKernelOnUseSimd(const bool useSimd)
+{
+    if (useSimd)
+    {
+#if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS && GMX_USE_SIMD_KERNELS
+        /* FIXME: Here SimdDataTypes should be used to enable SIMD. So far, the code in nb_free_energy_kernel is not adapted to SIMD */
+        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                      elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
+#else
+        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                      elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
+#endif
+    }
+    else
+    {
+        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                      elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
+    }
+}
+
+template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald>
+static KernelFunction dispatchKernelOnVdwModifier(const bool vdwModifierIsPotSwitch, const bool useSimd)
 {
     if (vdwModifierIsPotSwitch)
     {
-        return (nb_free_energy_kernel<softCoreTreatment, scLambdasOrAlphasDiffer,
-                                      vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, true>);
+        return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                        elecInteractionTypeIsEwald, true>(useSimd));
     }
     else
     {
-        return (nb_free_energy_kernel<softCoreTreatment, scLambdasOrAlphasDiffer,
-                                      vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, false>);
+        return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                        elecInteractionTypeIsEwald, false>(useSimd));
     }
 }
 
-template<SoftCoreTreatment softCoreTreatment, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald>
+template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald>
 static KernelFunction dispatchKernelOnElecInteractionType(const bool elecInteractionTypeIsEwald,
-                                                          const bool vdwModifierIsPotSwitch)
+                                                          const bool vdwModifierIsPotSwitch,
+                                                          const bool useSimd)
 {
     if (elecInteractionTypeIsEwald)
     {
-        return (dispatchKernelOnVdwModifier<softCoreTreatment, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, true>(
-                vdwModifierIsPotSwitch));
+        return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, true>(
+                vdwModifierIsPotSwitch, useSimd));
     }
     else
     {
-        return (dispatchKernelOnVdwModifier<softCoreTreatment, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, false>(
-                vdwModifierIsPotSwitch));
+        return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, false>(
+                vdwModifierIsPotSwitch, useSimd));
     }
 }
 
-template<SoftCoreTreatment softCoreTreatment, bool scLambdasOrAlphasDiffer>
+template<bool useSoftCore, bool scLambdasOrAlphasDiffer>
 static KernelFunction dispatchKernelOnVdwInteractionType(const bool vdwInteractionTypeIsEwald,
                                                          const bool elecInteractionTypeIsEwald,
-                                                         const bool vdwModifierIsPotSwitch)
+                                                         const bool vdwModifierIsPotSwitch,
+                                                         const bool useSimd)
 {
     if (vdwInteractionTypeIsEwald)
     {
-        return (dispatchKernelOnElecInteractionType<softCoreTreatment, scLambdasOrAlphasDiffer, true>(
-                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch));
+        return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, true>(
+                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
     }
     else
     {
-        return (dispatchKernelOnElecInteractionType<softCoreTreatment, scLambdasOrAlphasDiffer, false>(
-                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch));
+        return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, false>(
+                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
     }
 }
 
-template<SoftCoreTreatment softCoreTreatment>
+template<bool useSoftCore>
 static KernelFunction dispatchKernelOnScLambdasOrAlphasDifference(const bool scLambdasOrAlphasDiffer,
                                                                   const bool vdwInteractionTypeIsEwald,
                                                                   const bool elecInteractionTypeIsEwald,
-                                                                  const bool vdwModifierIsPotSwitch)
+                                                                  const bool vdwModifierIsPotSwitch,
+                                                                  const bool useSimd)
 {
     if (scLambdasOrAlphasDiffer)
     {
-        return (dispatchKernelOnVdwInteractionType<softCoreTreatment, true>(
-                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch));
+        return (dispatchKernelOnVdwInteractionType<useSoftCore, true>(
+                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
     }
     else
     {
-        return (dispatchKernelOnVdwInteractionType<softCoreTreatment, false>(
-                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch));
+        return (dispatchKernelOnVdwInteractionType<useSoftCore, false>(
+                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
     }
 }
 
@@ -947,25 +942,20 @@ static KernelFunction dispatchKernel(const bool        scLambdasOrAlphasDiffer,
                                      const bool        vdwInteractionTypeIsEwald,
                                      const bool        elecInteractionTypeIsEwald,
                                      const bool        vdwModifierIsPotSwitch,
+                                     const bool        useSimd,
                                      const t_forcerec* fr)
 {
     if (fr->sc_alphacoul == 0 && fr->sc_alphavdw == 0)
     {
-        return (dispatchKernelOnScLambdasOrAlphasDifference<SoftCoreTreatment::None>(
+        return (dispatchKernelOnScLambdasOrAlphasDifference<false>(
                 scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald,
-                vdwModifierIsPotSwitch));
-    }
-    else if (fr->sc_r_power == 6.0_real)
-    {
-        return (dispatchKernelOnScLambdasOrAlphasDifference<SoftCoreTreatment::RPower6>(
-                scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald,
-                vdwModifierIsPotSwitch));
+                vdwModifierIsPotSwitch, useSimd));
     }
     else
     {
-        return (dispatchKernelOnScLambdasOrAlphasDifference<SoftCoreTreatment::RPower48>(
+        return (dispatchKernelOnScLambdasOrAlphasDifference<true>(
                 scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald,
-                vdwModifierIsPotSwitch));
+                vdwModifierIsPotSwitch, useSimd));
     }
 }
 
@@ -985,12 +975,13 @@ void gmx_nb_free_energy_kernel(const t_nblist*            nlist,
     const bool elecInteractionTypeIsEwald = (EEL_PME_EWALD(fr->ic->eeltype));
     const bool vdwModifierIsPotSwitch     = (fr->ic->vdw_modifier == eintmodPOTSWITCH);
     bool       scLambdasOrAlphasDiffer    = true;
+    const bool useSimd                    = fr->use_simd_kernels;
 
     if (fr->sc_alphacoul == 0 && fr->sc_alphavdw == 0)
     {
         scLambdasOrAlphasDiffer = false;
     }
-    else if (fr->sc_r_power == 6.0_real || fr->sc_r_power == 48.0_real)
+    else if (fr->sc_r_power == 6.0_real)
     {
         if (kernel_data->lambda[efptCOUL] == kernel_data->lambda[efptVDW] && fr->sc_alphacoul == fr->sc_alphavdw)
         {
@@ -1001,7 +992,9 @@ void gmx_nb_free_energy_kernel(const t_nblist*            nlist,
     {
         GMX_RELEASE_ASSERT(false, "Unsupported soft-core r-power");
     }
-    KernelFunction kernelFunc = dispatchKernel(scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
-                                               elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, fr);
+
+    KernelFunction kernelFunc;
+    kernelFunc = dispatchKernel(scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd, fr);
     kernelFunc(nlist, xx, ff, fr, mdatoms, kernel_data, nrnb);
 }