Avoid overflows in free-energy kernel
authorBerk Hess <hess@kth.se>
Fri, 1 Oct 2021 21:01:24 +0000 (21:01 +0000)
committerBerk Hess <hess@kth.se>
Fri, 1 Oct 2021 21:01:24 +0000 (21:01 +0000)
src/gromacs/gmxlib/nonbonded/nb_free_energy.cpp

index dbb17010c2ea80573c290e2328d9c0802eb39cef..c8e46d15d3a5cb381c876ddc79941fc6fabf9a39 100644 (file)
@@ -90,46 +90,71 @@ struct SimdDataTypes
 };
 #endif
 
-template<class RealType, class BoolType>
+/*! \brief Lower limit for square interaction distances in nonbonded kernels.
+ *
+ * This is a mimimum on r^2 to avoid overflows when computing r^6.
+ * This will only affect results for soft-cored interaction at distances smaller
+ * than 1e-6 and will limit extremely high foreign energies for overlapping atoms.
+ * Note that we could use a somewhat smaller minimum in double precision.
+ * But because invsqrt in double precision can use single precision, this number
+ * can not be much smaller, we use the same number for simplicity.
+ */
+constexpr real c_minDistanceSquared = 1.0e-12_real;
+
+/*! \brief Higher limit for r^-6 used for Lennard-Jones interactions
+ *
+ * This is needed to avoid overflow of LJ energy and force terms for excluded
+ * atoms and foreign energies of hard-core states of overlapping atoms.
+ * Note that in single precision this value leaves room for C12 coefficients up to 3.4e8.
+ */
+constexpr real c_maxRInvSix = 1.0e15_real;
+
+template<bool computeForces, class RealType>
 static inline void
-pmeCoulombCorrectionVF(const RealType rSq, const real beta, RealType* pot, RealType* force, const BoolType mask)
+pmeCoulombCorrectionVF(const RealType rSq, const real beta, RealType* pot, RealType gmx_unused* force)
 {
-    const RealType brsq = gmx::selectByMask(rSq * beta * beta, mask);
-    *force              = -brsq * beta * gmx::pmeForceCorrection(brsq);
-    *pot                = beta * gmx::pmePotentialCorrection(brsq);
+    const RealType brsq = rSq * beta * beta;
+    if constexpr (computeForces)
+    {
+        *force = -brsq * beta * gmx::pmeForceCorrection(brsq);
+    }
+    *pot = beta * gmx::pmePotentialCorrection(brsq);
 }
 
-template<class RealType, class BoolType>
+template<bool computeForces, class RealType, class BoolType>
 static inline void pmeLJCorrectionVF(const RealType rInv,
                                      const RealType rSq,
                                      const real     ewaldLJCoeffSq,
                                      const real     ewaldLJCoeffSixDivSix,
                                      RealType*      pot,
-                                     RealType*      force,
-                                     const BoolType mask,
-                                     const BoolType bIiEqJnr)
+                                     RealType gmx_unused* force,
+                                     const BoolType       mask,
+                                     const BoolType       bIiEqJnr)
 {
     // We mask rInv to get zero force and potential for masked out pair interactions
-    const RealType rInvSq  = gmx::selectByMask(rInv * rInv, mask);
+    const RealType rInvSq  = rInv * rInv;
     const RealType rInvSix = rInvSq * rInvSq * rInvSq;
     // Mask rSq to avoid underflow in exp()
     const RealType coeffSqRSq       = ewaldLJCoeffSq * gmx::selectByMask(rSq, mask);
     const RealType expNegCoeffSqRSq = gmx::exp(-coeffSqRSq);
     const RealType poly             = 1.0_real + coeffSqRSq + 0.5_real * coeffSqRSq * coeffSqRSq;
-    *force = rInvSix - expNegCoeffSqRSq * (rInvSix * poly + ewaldLJCoeffSixDivSix);
-    *force = *force * rInvSq;
+    if constexpr (computeForces)
+    {
+        *force = rInvSix - expNegCoeffSqRSq * (rInvSix * poly + ewaldLJCoeffSixDivSix);
+        *force = *force * rInvSq;
+    }
     // The self interaction is the limit for r -> 0 which we need to compute separately
     *pot = gmx::blend(
             rInvSix * (1.0_real - expNegCoeffSqRSq * poly), 0.5_real * ewaldLJCoeffSixDivSix, bIiEqJnr);
 }
 
-//! Computes r^(1/p) and 1/r^(1/p) for the standard p=6
-template<class RealType, class BoolType>
-static inline void pthRoot(const RealType r, RealType* pthRoot, RealType* invPthRoot, const BoolType mask)
+//! Computes r^(1/6) and 1/r^(1/6)
+template<class RealType>
+static inline void sixthRoot(const RealType r, RealType* sixthRoot, RealType* invSixthRoot)
 {
     RealType cbrtRes = gmx::cbrt(r);
-    *invPthRoot      = gmx::maskzInvsqrt(cbrtRes, mask);
-    *pthRoot         = gmx::maskzInv(*invPthRoot, mask);
+    *invSixthRoot    = gmx::invsqrt(cbrtRes);
+    *sixthRoot       = gmx::inv(*invSixthRoot);
 }
 
 template<class RealType>
@@ -232,7 +257,7 @@ static inline RealType potSwitchPotentialMod(const RealType potentialInp, const
 
 
 //! Templated free-energy non-bonded kernel
-template<typename DataTypes, bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
+template<typename DataTypes, bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch, bool computeForces>
 static void nb_free_energy_kernel(const t_nblist&                                  nlist,
                                   const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
                                   const int                                        ntype,
@@ -249,10 +274,10 @@ static void nb_free_energy_kernel(const t_nblist&
                                   gmx::ArrayRef<const real>                        lambda,
                                   t_nrnb* gmx_restrict                             nrnb,
                                   gmx::ArrayRefWithPadding<gmx::RVec> threadForceBuffer,
-                                  rvec*                               threadForceShiftBuffer,
-                                  gmx::ArrayRef<real>                 threadVc,
-                                  gmx::ArrayRef<real>                 threadVv,
-                                  gmx::ArrayRef<real>                 threadDvdl)
+                                  rvec gmx_unused*    threadForceShiftBuffer,
+                                  gmx::ArrayRef<real> threadVc,
+                                  gmx::ArrayRef<real> threadVv,
+                                  gmx::ArrayRef<real> threadDvdl)
 {
 #define STATE_A 0
 #define STATE_B 1
@@ -286,8 +311,7 @@ static void nb_free_energy_kernel(const t_nblist&
     const real            lam_power     = scParams.lambdaPower;
     const real gmx_unused sigma6_def    = scParams.sigma6WithInvalidSigma;
     const real gmx_unused sigma6_min    = scParams.sigma6Minimum;
-    const bool            doForces      = ((flags & GMX_NONBONDED_DO_FORCE) != 0);
-    const bool            doShiftForces = ((flags & GMX_NONBONDED_DO_SHIFTFORCE) != 0);
+    const bool gmx_unused doShiftForces = ((flags & GMX_NONBONDED_DO_SHIFTFORCE) != 0);
     const bool            doPotential   = ((flags & GMX_NONBONDED_DO_POTENTIAL) != 0);
 
     // Extract data from interaction_const_t
@@ -364,6 +388,10 @@ static void nb_free_energy_kernel(const t_nblist&
     GMX_RELEASE_ASSERT(!(vdwInteractionTypeIsEwald && vdwModifierIsPotSwitch),
                        "Can not apply soft-core to switched Ewald potentials");
 
+    const RealType            minDistanceSquared(c_minDistanceSquared);
+    const RealType            maxRInvSix(c_maxRInvSix);
+    const RealType gmx_unused floatMin(GMX_FLOAT_MIN);
+
     RealType dvdlCoul(zero);
     RealType dvdlVdw(zero);
 
@@ -392,8 +420,19 @@ static void nb_free_energy_kernel(const t_nblist&
     }
 
     // We need pointers to real for SIMD access
-    const real* gmx_restrict x            = coords.paddedConstArrayRef().data()[0];
-    real* gmx_restrict       forceRealPtr = threadForceBuffer.paddedArrayRef().data()[0];
+    const real* gmx_restrict x = coords.paddedConstArrayRef().data()[0];
+    real* gmx_restrict       forceRealPtr;
+    if constexpr (computeForces)
+    {
+        GMX_ASSERT(nri == 0 || !threadForceBuffer.empty(), "need a valid threadForceBuffer");
+
+        forceRealPtr = threadForceBuffer.paddedArrayRef().data()[0];
+
+        if (doShiftForces)
+        {
+            GMX_ASSERT(threadForceShiftBuffer != nullptr, "need a valid threadForceShiftBuffer");
+        }
+    }
 
     const real rlistSquared = gmx::square(rlist);
 
@@ -425,11 +464,11 @@ static void nb_free_energy_kernel(const t_nblist&
         RealType   fIZ(0);
 
 #if GMX_SIMD_HAVE_REAL
-        alignas(GMX_SIMD_ALIGNMENT) int preloadIi[DataTypes::simdRealWidth];
-        alignas(GMX_SIMD_ALIGNMENT) int preloadIs[DataTypes::simdRealWidth];
+        alignas(GMX_SIMD_ALIGNMENT) int            preloadIi[DataTypes::simdRealWidth];
+        alignas(GMX_SIMD_ALIGNMENT) int gmx_unused preloadIs[DataTypes::simdRealWidth];
 #else
-        int preloadIi[DataTypes::simdRealWidth];
-        int preloadIs[DataTypes::simdRealWidth];
+        int            preloadIi[DataTypes::simdRealWidth];
+        int gmx_unused preloadIs[DataTypes::simdRealWidth];
 #endif
         for (int s = 0; s < DataTypes::simdRealWidth; s++)
         {
@@ -554,7 +593,7 @@ static void nb_free_energy_kernel(const t_nblist&
             const RealType dX  = ix - jx;
             const RealType dY  = iy - jy;
             const RealType dZ  = iz - jz;
-            const RealType rSq = dX * dX + dY * dY + dZ * dZ;
+            RealType       rSq = dX * dX + dY * dY + dZ * dZ;
 
             BoolType withinCutoffMask = (rSq < rcutoff_max2);
 
@@ -606,13 +645,9 @@ static void nb_free_energy_kernel(const t_nblist&
                 alphaCoulEff = gmx::load<RealType>(preloadAlphaCoulEff);
             }
 
-            BoolType rSqValid = (zero < rSq);
-
-            /* The force at r=0 is zero, because of symmetry.
-             * But note that the potential is in general non-zero,
-             * since the soft-cored r will be non-zero.
-             */
-            rInv = gmx::maskzInvsqrt(rSq, rSqValid);
+            // Avoid overflow of r^-12 at distances near zero
+            rSq  = gmx::max(rSq, minDistanceSquared);
+            rInv = gmx::invsqrt(rSq);
             r    = rSq * rInv;
 
             RealType gmx_unused rp, rpm2;
@@ -637,8 +672,8 @@ static void nb_free_energy_kernel(const t_nblist&
              * bPairIncluded is true then withinCutoffMask must also be true. */
             if (gmx::anyTrue(withinCutoffMask && bPairIncluded))
             {
-                RealType fScalC[NSTATES], fScalV[NSTATES];
-                RealType vCoul[NSTATES], vVdw[NSTATES];
+                RealType gmx_unused fScalC[NSTATES], fScalV[NSTATES];
+                RealType            vCoul[NSTATES], vVdw[NSTATES];
                 for (int i = 0; i < NSTATES; i++)
                 {
                     fScalC[i] = zero;
@@ -656,17 +691,15 @@ static void nb_free_energy_kernel(const t_nblist&
                     {
                         if constexpr (useSoftCore)
                         {
-                            RealType divisor      = (alphaCoulEff * lFacCoul[i] * sigma6[i] + rp);
-                            BoolType validDivisor = (zero < divisor);
-                            rPInvC                = gmx::maskzInv(divisor, validDivisor);
-                            pthRoot(rPInvC, &rInvC, &rC, validDivisor);
+                            RealType divisor = (alphaCoulEff * lFacCoul[i] * sigma6[i] + rp);
+                            rPInvC           = gmx::inv(divisor);
+                            sixthRoot(rPInvC, &rInvC, &rC);
 
                             if constexpr (scLambdasOrAlphasDiffer)
                             {
-                                RealType divisor      = (alphaVdwEff * lFacVdw[i] * sigma6[i] + rp);
-                                BoolType validDivisor = (zero < divisor);
-                                rPInvV                = gmx::maskzInv(divisor, validDivisor);
-                                pthRoot(rPInvV, &rInvV, &rV, validDivisor);
+                                RealType divisor = (alphaVdwEff * lFacVdw[i] * sigma6[i] + rp);
+                                rPInvV           = gmx::inv(divisor);
+                                sixthRoot(rPInvV, &rInvV, &rV);
                             }
                             else
                             {
@@ -704,17 +737,26 @@ static void nb_free_energy_kernel(const t_nblist&
                         {
                             if constexpr (elecInteractionTypeIsEwald)
                             {
-                                vCoul[i]  = ewaldPotential(qq[i], rInvC, sh_ewald);
-                                fScalC[i] = ewaldScalarForce(qq[i], rInvC);
+                                vCoul[i] = ewaldPotential(qq[i], rInvC, sh_ewald);
+                                if constexpr (computeForces)
+                                {
+                                    fScalC[i] = ewaldScalarForce(qq[i], rInvC);
+                                }
                             }
                             else
                             {
-                                vCoul[i]  = reactionFieldPotential(qq[i], rInvC, rC, krf, crf);
-                                fScalC[i] = reactionFieldScalarForce(qq[i], rInvC, rC, krf, two);
+                                vCoul[i] = reactionFieldPotential(qq[i], rInvC, rC, krf, crf);
+                                if constexpr (computeForces)
+                                {
+                                    fScalC[i] = reactionFieldScalarForce(qq[i], rInvC, rC, krf, two);
+                                }
                             }
 
-                            vCoul[i]  = gmx::selectByMask(vCoul[i], computeElecInteraction);
-                            fScalC[i] = gmx::selectByMask(fScalC[i], computeElecInteraction);
+                            vCoul[i] = gmx::selectByMask(vCoul[i], computeElecInteraction);
+                            if constexpr (computeForces)
+                            {
+                                fScalC[i] = gmx::selectByMask(fScalC[i], computeElecInteraction);
+                            }
                         }
 
                         /* Only process the VDW interactions if we either
@@ -743,12 +785,21 @@ static void nb_free_energy_kernel(const t_nblist&
                             {
                                 rInv6 = calculateRinv6(rInvV);
                             }
+                            // Avoid overflow at short distance for masked exclusions and
+                            // for foreign energy calculations at a hard core end state.
+                            // Note that we should limit r^-6, and thus also r^-12, and
+                            // not only r^-12, as that could lead to erroneously low instead
+                            // of very high foreign energies.
+                            rInv6           = gmx::min(rInv6, maxRInvSix);
                             RealType vVdw6  = calculateVdw6(c6[i], rInv6);
                             RealType vVdw12 = calculateVdw12(c12[i], rInv6);
 
                             vVdw[i] = lennardJonesPotential(
                                     vVdw6, vVdw12, c6[i], c12[i], repulsionShift, dispersionShift, oneSixth, oneTwelfth);
-                            fScalV[i] = lennardJonesScalarForce(vVdw6, vVdw12);
+                            if constexpr (computeForces)
+                            {
+                                fScalV[i] = lennardJonesScalarForce(vVdw6, vVdw12);
+                            }
 
                             if constexpr (vdwInteractionTypeIsEwald)
                             {
@@ -768,24 +819,33 @@ static void nb_free_energy_kernel(const t_nblist&
                                 const RealType d2      = d * d;
                                 const RealType sw =
                                         one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5));
-                                const RealType dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4));
 
-                                fScalV[i] = potSwitchScalarForceMod(
-                                        fScalV[i], vVdw[i], sw, rV, dsw, potSwitchMask);
+                                if constexpr (computeForces)
+                                {
+                                    const RealType dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4));
+                                    fScalV[i]          = potSwitchScalarForceMod(
+                                            fScalV[i], vVdw[i], sw, rV, dsw, potSwitchMask);
+                                }
                                 vVdw[i] = potSwitchPotentialMod(vVdw[i], sw, potSwitchMask);
                             }
 
-                            vVdw[i]   = gmx::selectByMask(vVdw[i], computeVdwInteraction);
-                            fScalV[i] = gmx::selectByMask(fScalV[i], computeVdwInteraction);
+                            vVdw[i] = gmx::selectByMask(vVdw[i], computeVdwInteraction);
+                            if constexpr (computeForces)
+                            {
+                                fScalV[i] = gmx::selectByMask(fScalV[i], computeVdwInteraction);
+                            }
                         }
 
-                        /* fScalC (and fScalV) now contain: dV/drC * rC
-                         * Now we multiply by rC^-p, so it will be: dV/drC * rC^1-p
-                         * Further down we first multiply by r^p-2 and then by
-                         * the vector r, which in total gives: dV/drC * (r/rC)^1-p
-                         */
-                        fScalC[i] = fScalC[i] * rPInvC;
-                        fScalV[i] = fScalV[i] * rPInvV;
+                        if constexpr (computeForces)
+                        {
+                            /* fScalC (and fScalV) now contain: dV/drC * rC
+                             * Now we multiply by rC^-6, so it will be: dV/drC * rC^-5
+                             * Further down we first multiply by r^4 and then by
+                             * the vector r, which in total gives: dV/drC * (r/rC)^-5
+                             */
+                            fScalC[i] = fScalC[i] * rPInvC;
+                            fScalV[i] = fScalV[i] * rPInvV;
+                        }
                     } // end of block requiring nonZeroState
                 }     // end for (int i = 0; i < NSTATES; i++)
 
@@ -798,8 +858,11 @@ static void nb_free_energy_kernel(const t_nblist&
                         vCTot = vCTot + LFC[i] * vCoul[i];
                         vVTot = vVTot + LFV[i] * vVdw[i];
 
-                        fScal = fScal + LFC[i] * fScalC[i] * rpm2;
-                        fScal = fScal + LFV[i] * fScalV[i] * rpm2;
+                        if constexpr (computeForces)
+                        {
+                            fScal = fScal + LFC[i] * fScalC[i] * rpm2;
+                            fScal = fScal + LFV[i] * fScalV[i] * rpm2;
+                        }
 
                         if constexpr (useSoftCore)
                         {
@@ -858,8 +921,11 @@ static void nb_free_energy_kernel(const t_nblist&
                  */
                 RealType v_lr, f_lr;
 
-                pmeCoulombCorrectionVF(rSq, ewaldBeta, &v_lr, &f_lr, rSqValid);
-                f_lr = f_lr * rInv * rInv;
+                pmeCoulombCorrectionVF<computeForces>(rSq, ewaldBeta, &v_lr, &f_lr);
+                if constexpr (computeForces)
+                {
+                    f_lr = f_lr * rInv * rInv;
+                }
 
                 /* Note that any possible Ewald shift has already been applied in
                  * the normal interaction part above.
@@ -875,7 +941,10 @@ static void nb_free_energy_kernel(const t_nblist&
                 for (int i = 0; i < NSTATES; i++)
                 {
                     vCTot = vCTot - gmx::selectByMask(LFC[i] * qq[i] * v_lr, computeElecEwaldInteraction);
-                    fScal = fScal - gmx::selectByMask(LFC[i] * qq[i] * f_lr, computeElecEwaldInteraction);
+                    if constexpr (computeForces)
+                    {
+                        fScal = fScal - gmx::selectByMask(LFC[i] * qq[i] * f_lr, computeElecEwaldInteraction);
+                    }
                     dvdlCoul = dvdlCoul
                                - gmx::selectByMask(DLF[i] * qq[i] * v_lr, computeElecEwaldInteraction);
                 }
@@ -894,19 +963,22 @@ static void nb_free_energy_kernel(const t_nblist&
                  */
 
                 RealType v_lr, f_lr;
-                pmeLJCorrectionVF(
+                pmeLJCorrectionVF<computeForces>(
                         rInv, rSq, ewaldLJCoeffSq, ewaldLJCoeffSixDivSix, &v_lr, &f_lr, computeVdwEwaldInteraction, bIiEqJnr);
                 v_lr = v_lr * oneSixth;
 
                 for (int i = 0; i < NSTATES; i++)
                 {
                     vVTot = vVTot + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
-                    fScal = fScal + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * f_lr, computeVdwEwaldInteraction);
+                    if constexpr (computeForces)
+                    {
+                        fScal = fScal + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * f_lr, computeVdwEwaldInteraction);
+                    }
                     dvdlVdw = dvdlVdw + gmx::selectByMask(DLF[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
                 }
             }
 
-            if (doForces && gmx::anyTrue(fScal != zero))
+            if (computeForces && gmx::anyTrue(fScal != zero))
             {
                 const RealType tX = fScal * dX;
                 const RealType tY = fScal * dY;
@@ -921,14 +993,15 @@ static void nb_free_energy_kernel(const t_nblist&
 
         if (havePairsWithinCutoff)
         {
-            if (doForces)
+            if constexpr (computeForces)
             {
                 gmx::transposeScatterIncrU<3>(forceRealPtr, preloadIi, fIX, fIY, fIZ);
-            }
-            if (doShiftForces)
-            {
-                gmx::transposeScatterIncrU<3>(
-                        reinterpret_cast<real*>(threadForceShiftBuffer), preloadIs, fIX, fIY, fIZ);
+
+                if (doShiftForces)
+                {
+                    gmx::transposeScatterIncrU<3>(
+                            reinterpret_cast<real*>(threadForceShiftBuffer), preloadIs, fIX, fIY, fIZ);
+                }
             }
             if (doPotential)
             {
@@ -990,52 +1063,70 @@ typedef void (*KernelFunction)(const t_nblist&
                                gmx::ArrayRef<real> threadVv,
                                gmx::ArrayRef<real> threadDvdl);
 
-template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
+template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch, bool computeForces>
 static KernelFunction dispatchKernelOnUseSimd(const bool useSimd)
 {
     if (useSimd)
     {
 #if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS && GMX_USE_SIMD_KERNELS
-        return (nb_free_energy_kernel<SimdDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
+        return (nb_free_energy_kernel<SimdDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces>);
 #else
-        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
+        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces>);
 #endif
     }
     else
     {
-        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
+        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces>);
     }
 }
 
-template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald>
-static KernelFunction dispatchKernelOnVdwModifier(const bool vdwModifierIsPotSwitch, const bool useSimd)
+template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
+static KernelFunction dispatchKernelOnComputeForces(const bool computeForces, const bool useSimd)
 {
-    if (vdwModifierIsPotSwitch)
+    if (computeForces)
     {
-        return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, true>(
+        return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, true>(
                 useSimd));
     }
     else
     {
-        return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, false>(
+        return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, false>(
                 useSimd));
     }
 }
 
+template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald>
+static KernelFunction dispatchKernelOnVdwModifier(const bool vdwModifierIsPotSwitch,
+                                                  const bool computeForces,
+                                                  const bool useSimd)
+{
+    if (vdwModifierIsPotSwitch)
+    {
+        return (dispatchKernelOnComputeForces<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, true>(
+                computeForces, useSimd));
+    }
+    else
+    {
+        return (dispatchKernelOnComputeForces<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, false>(
+                computeForces, useSimd));
+    }
+}
+
 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald>
 static KernelFunction dispatchKernelOnElecInteractionType(const bool elecInteractionTypeIsEwald,
                                                           const bool vdwModifierIsPotSwitch,
+                                                          const bool computeForces,
                                                           const bool useSimd)
 {
     if (elecInteractionTypeIsEwald)
     {
         return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, true>(
-                vdwModifierIsPotSwitch, useSimd));
+                vdwModifierIsPotSwitch, computeForces, useSimd));
     }
     else
     {
         return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, false>(
-                vdwModifierIsPotSwitch, useSimd));
+                vdwModifierIsPotSwitch, computeForces, useSimd));
     }
 }
 
@@ -1043,17 +1134,18 @@ template<bool useSoftCore, bool scLambdasOrAlphasDiffer>
 static KernelFunction dispatchKernelOnVdwInteractionType(const bool vdwInteractionTypeIsEwald,
                                                          const bool elecInteractionTypeIsEwald,
                                                          const bool vdwModifierIsPotSwitch,
+                                                         const bool computeForces,
                                                          const bool useSimd)
 {
     if (vdwInteractionTypeIsEwald)
     {
         return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, true>(
-                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
+                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces, useSimd));
     }
     else
     {
         return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, false>(
-                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
+                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces, useSimd));
     }
 }
 
@@ -1062,17 +1154,18 @@ static KernelFunction dispatchKernelOnScLambdasOrAlphasDifference(const bool scL
                                                                   const bool vdwInteractionTypeIsEwald,
                                                                   const bool elecInteractionTypeIsEwald,
                                                                   const bool vdwModifierIsPotSwitch,
+                                                                  const bool computeForces,
                                                                   const bool useSimd)
 {
     if (scLambdasOrAlphasDiffer)
     {
         return (dispatchKernelOnVdwInteractionType<useSoftCore, true>(
-                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
+                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces, useSimd));
     }
     else
     {
         return (dispatchKernelOnVdwInteractionType<useSoftCore, false>(
-                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
+                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces, useSimd));
     }
 }
 
@@ -1080,6 +1173,7 @@ static KernelFunction dispatchKernel(const bool                 scLambdasOrAlpha
                                      const bool                 vdwInteractionTypeIsEwald,
                                      const bool                 elecInteractionTypeIsEwald,
                                      const bool                 vdwModifierIsPotSwitch,
+                                     const bool                 computeForces,
                                      const bool                 useSimd,
                                      const interaction_const_t& ic)
 {
@@ -1089,6 +1183,7 @@ static KernelFunction dispatchKernel(const bool                 scLambdasOrAlpha
                                                                    vdwInteractionTypeIsEwald,
                                                                    elecInteractionTypeIsEwald,
                                                                    vdwModifierIsPotSwitch,
+                                                                   computeForces,
                                                                    useSimd));
     }
     else
@@ -1097,6 +1192,7 @@ static KernelFunction dispatchKernel(const bool                 scLambdasOrAlpha
                                                                   vdwInteractionTypeIsEwald,
                                                                   elecInteractionTypeIsEwald,
                                                                   vdwModifierIsPotSwitch,
+                                                                  computeForces,
                                                                   useSimd));
     }
 }
@@ -1137,6 +1233,7 @@ void gmx_nb_free_energy_kernel(const t_nblist&
     const bool  vdwInteractionTypeIsEwald  = (EVDW_PME(ic.vdwtype));
     const bool  elecInteractionTypeIsEwald = (EEL_PME_EWALD(ic.eeltype));
     const bool  vdwModifierIsPotSwitch     = (ic.vdw_modifier == InteractionModifiers::PotSwitch);
+    const bool  computeForces              = ((flags & GMX_NONBONDED_DO_FORCE) != 0);
     bool        scLambdasOrAlphasDiffer    = true;
 
     if (scParams.alphaCoulomb == 0 && scParams.alphaVdw == 0)
@@ -1158,6 +1255,7 @@ void gmx_nb_free_energy_kernel(const t_nblist&
                                 vdwInteractionTypeIsEwald,
                                 elecInteractionTypeIsEwald,
                                 vdwModifierIsPotSwitch,
+                                computeForces,
                                 useSimd,
                                 ic);
     kernelFunc(nlist,