Inline functions of calculations of FE interaction types.
authorMagnus Lundborg <lundborg.magnus@gmail.com>
Wed, 19 Jun 2019 14:13:33 +0000 (16:13 +0200)
committerMark Abraham <mark.j.abraham@gmail.com>
Wed, 4 Sep 2019 10:14:21 +0000 (12:14 +0200)
The functions are so far only templated on the real
requirements of the softcore that is used.

This is one step towards templating the calculations
for SIMD.

Refs #2997.

Change-Id: I3fd119dce30f95eba4f8cd6f139f99260acf0e22

src/gromacs/gmxlib/nonbonded/nb_free_energy.cpp

index 8303ae6007348e5ae6e22735054521d68f366716..cd821ef41411c034f2e9ccea6e145c83d2fce6e8 100644 (file)
@@ -110,6 +110,119 @@ inline void pthRoot<SoftCoreTreatment::RPower48>(const double  r,
     *invPthRoot = 1/(*pthRoot);
 }
 
+template<SoftCoreTreatment softCoreTreatment>
+static inline real calculateSigmaPow(const real sigma6)
+{
+    if (softCoreTreatment == SoftCoreTreatment::RPower6)
+    {
+        return sigma6;
+    }
+    else
+    {
+        real sigmaPow = sigma6 * sigma6;     /* sigma^12 */
+        sigmaPow      = sigmaPow * sigmaPow; /* sigma^24 */
+        sigmaPow      = sigmaPow * sigmaPow; /* sigma^48 */
+        return (sigmaPow);
+    }
+}
+
+template<SoftCoreTreatment softCoreTreatment, class SCReal>
+static inline real calculateRinv6(const SCReal rinvV)
+{
+    if (softCoreTreatment == SoftCoreTreatment::RPower6)
+    {
+        return rinvV;
+    }
+    else
+    {
+        real rinv6 = rinvV * rinvV;
+        return (rinv6 * rinv6 * rinv6);
+    }
+}
+
+static inline real calculateVdw6(const real c6, const real rinv6)
+{
+    return (c6 * rinv6);
+}
+
+static inline real calculateVdw12(const real c12, const real rinv6)
+{
+    return (c12 * rinv6 * rinv6);
+}
+
+/* reaction-field electrostatics */
+template<class SCReal>
+static inline SCReal
+reactionFieldScalarForce(const real qq, const real rinv, const SCReal r, const real krf,
+                         const real two)
+{
+    return (qq*(rinv - two*krf*r*r));
+}
+template<class SCReal>
+static inline real
+reactionFieldPotential(const real qq, const real rinv, const SCReal r, const real krf,
+                       const real potentialShift)
+{
+    return (qq*(rinv + krf*r*r-potentialShift));
+}
+
+/* Ewald electrostatics */
+static inline real
+ewaldScalarForce(const real coulomb, const real rinv)
+{
+    return (coulomb*rinv);
+}
+static inline real
+ewaldPotential(const real coulomb, const real rinv, const real potentialShift)
+{
+    return (coulomb*(rinv-potentialShift));
+}
+
+/* cutoff LJ */
+static inline real
+lennardJonesScalarForce(const real v6, const real v12)
+{
+    return (v12 - v6);
+}
+static inline real
+lennardJonesPotential(const real v6, const real v12, const real c6, const real c12, const real repulsionShift,
+                      const real dispersionShift, const real onesixth, const real onetwelfth)
+{
+    return ((v12 + c12*repulsionShift)*onetwelfth - (v6 + c6*dispersionShift)*onesixth);
+}
+
+/* Ewald LJ */
+static inline real
+ewaldLennardJonesGridSubtract(const real c6grid, const real potentialShift, const real onesixth)
+{
+    return (c6grid * potentialShift * onesixth);
+}
+
+/* LJ Potential switch */
+template<class SCReal>
+static inline SCReal
+potSwitchScalarForceMod(const SCReal fScalarInp, const real potential, const real sw, const SCReal r, const real rVdw, const real dsw, const real zero)
+{
+    if (r < rVdw)
+    {
+        SCReal fScalar = fScalarInp * sw - r * potential * dsw;
+        return (fScalar);
+    }
+    return (zero);
+}
+template<class SCReal>
+static inline real
+potSwitchPotentialMod(const real potentialInp, const real sw, const SCReal r, const real rVdw, const real zero)
+{
+    if (r < rVdw)
+    {
+        real potential = potentialInp * sw;
+        return (potential);
+    }
+    return (zero);
+}
+
+
 //! Templated free-energy non-bonded kernel
 template<SoftCoreTreatment softCoreTreatment, bool scLambdasOrAlphasDiffer>
 static void
@@ -133,13 +246,12 @@ nb_free_energy_kernel(const t_nblist * gmx_restrict    nlist,
     real          tx, ty, tz, Fscal;
     SCReal        FscalC[NSTATES], FscalV[NSTATES];  /* Needs double for sc_power==48 */
     real          Vcoul[NSTATES], Vvdw[NSTATES];
-    real          rinv6, r;
     real          iqA, iqB;
     real          qq[NSTATES], vctot;
     int           ntiA, ntiB, tj[NSTATES];
-    real          Vvdw6, Vvdw12, vvtot;
+    real          vvtot;
     real          ix, iy, iz, fix, fiy, fiz;
-    real          dx, dy, dz, rsq, rinv;
+    real          dx, dy, dz, r, rsq, rinv;
     real          c6[NSTATES], c12[NSTATES], c6grid;
     real          LFC[NSTATES], LFV[NSTATES], DLF[NSTATES];
     SCReal        dvdl_coul, dvdl_vdw;
@@ -173,7 +285,6 @@ nb_free_energy_kernel(const t_nblist * gmx_restrict    nlist,
     real          rcutoff_max2;
     const real *  tab_ewald_F_lj = nullptr;
     const real *  tab_ewald_V_lj = nullptr;
-    real          d, d2, sw, dsw;
     real          vdw_swV3, vdw_swV4, vdw_swV5, vdw_swF2, vdw_swF3, vdw_swF4;
     gmx_bool      bComputeVdwInteraction, bComputeElecInteraction;
     const real *  ewtab = nullptr;
@@ -243,7 +354,7 @@ nb_free_energy_kernel(const t_nblist * gmx_restrict    nlist,
 
     if (ic->vdw_modifier == eintmodPOTSWITCH)
     {
-        d               = ic->rvdw - ic->rvdw_switch;
+        real d          = ic->rvdw - ic->rvdw_switch;
         vdw_swV3        = -10.0/(d*d*d);
         vdw_swV4        =  15.0/(d*d*d*d);
         vdw_swV5        =  -6.0/(d*d*d*d*d);
@@ -458,16 +569,7 @@ nb_free_energy_kernel(const t_nblist * gmx_restrict    nlist,
                         {
                             sigma6[i]       = sigma6_def;
                         }
-                        if (softCoreTreatment == SoftCoreTreatment::RPower6)
-                        {
-                            sigma_pow[i]    = sigma6[i];
-                        }
-                        else
-                        {
-                            sigma_pow[i]    = sigma6[i]*sigma6[i];       /* sigma^12 */
-                            sigma_pow[i]    = sigma_pow[i]*sigma_pow[i]; /* sigma^24 */
-                            sigma_pow[i]    = sigma_pow[i]*sigma_pow[i]; /* sigma^48 */
-                        }
+                        sigma_pow[i] = calculateSigmaPow<softCoreTreatment>(sigma6[i]);
                     }
                 }
 
@@ -538,15 +640,13 @@ nb_free_energy_kernel(const t_nblist * gmx_restrict    nlist,
                         {
                             if (bEwald)
                             {
-                                /* Ewald FEP is done only on the 1/r part */
-                                Vcoul[i]   = qq[i]*(rinvC-sh_ewald);
-                                FscalC[i]  = qq[i]*rinvC;
+                                Vcoul[i]  = ewaldPotential(qq[i], rinvC, sh_ewald);
+                                FscalC[i] = ewaldScalarForce(qq[i], rinvC);
                             }
                             else
                             {
-                                /* reaction-field */
-                                Vcoul[i]   = qq[i]*(rinvC + krf*rC*rC-crf);
-                                FscalC[i]  = qq[i]*(rinvC - two*krf*rC*rC);
+                                Vcoul[i]  = reactionFieldPotential(qq[i], rinvC, rC, krf, crf);
+                                FscalC[i] = reactionFieldScalarForce(qq[i], rinvC, rC, krf, two);
                             }
                         }
 
@@ -560,43 +660,37 @@ nb_free_energy_kernel(const t_nblist * gmx_restrict    nlist,
                             (!bEwaldLJ && rV < rvdw);
                         if ((c6[i] != 0 || c12[i] != 0) && bComputeVdwInteraction)
                         {
-                            /* cutoff LJ, also handles part of Ewald LJ */
+                            real rinv6;
                             if (softCoreTreatment == SoftCoreTreatment::RPower6)
                             {
-                                rinv6        = rpinvV;
+                                rinv6  = calculateRinv6<softCoreTreatment>(rpinvV);
                             }
                             else
                             {
-                                rinv6        = rinvV*rinvV;
-                                rinv6        = rinv6*rinv6*rinv6;
+                                rinv6  = calculateRinv6<softCoreTreatment>(rinvV);
                             }
-                            Vvdw6            = c6[i]*rinv6;
-                            Vvdw12           = c12[i]*rinv6*rinv6;
+                            real Vvdw6  = calculateVdw6(c6[i], rinv6);
+                            real Vvdw12 = calculateVdw12(c12[i], rinv6);
 
-                            Vvdw[i]          = ( (Vvdw12 + c12[i]*repulsionShift)*onetwelfth
-                                                 - (Vvdw6 + c6[i]*dispersionShift)*onesixth);
-                            FscalV[i]        = Vvdw12 - Vvdw6;
+                            Vvdw[i]     = lennardJonesPotential(Vvdw6, Vvdw12, c6[i], c12[i], repulsionShift, dispersionShift, onesixth, onetwelfth);
+                            FscalV[i]   = lennardJonesScalarForce(Vvdw6, Vvdw12);
 
                             if (bEwaldLJ)
                             {
                                 /* Subtract the grid potential at the cut-off */
-                                c6grid      = nbfp_grid[tj[i]];
-                                Vvdw[i]    += c6grid*sh_lj_ewald*onesixth;
+                                Vvdw[i] += ewaldLennardJonesGridSubtract(nbfp_grid[tj[i]], sh_lj_ewald, onesixth);
                             }
 
                             if (ic->vdw_modifier == eintmodPOTSWITCH)
                             {
-                                d                = rV - ic->rvdw_switch;
-                                d                = (d > zero) ? d : zero;
-                                d2               = d*d;
-                                sw               = one+d2*d*(vdw_swV3+d*(vdw_swV4+d*vdw_swV5));
-                                dsw              = d2*(vdw_swF2+d*(vdw_swF3+d*vdw_swF4));
-
-                                FscalV[i]        = FscalV[i]*sw - rV*Vvdw[i]*dsw;
-                                Vvdw[i]         *= sw;
-
-                                FscalV[i]  = (rV < rvdw) ? FscalV[i] : zero;
-                                Vvdw[i]    = (rV < rvdw) ? Vvdw[i] : zero;
+                                real d    = rV - ic->rvdw_switch;
+                                d         = (d > zero) ? d : zero;
+                                real d2   = d*d;
+                                real sw   = one+d2*d*(vdw_swV3+d*(vdw_swV4+d*vdw_swV5));
+                                real dsw  = d2*(vdw_swF2+d*(vdw_swF3+d*vdw_swF4));
+
+                                FscalV[i] = potSwitchScalarForceMod(FscalV[i], Vvdw[i], sw, rV, rvdw, dsw, zero);
+                                Vvdw[i]   = potSwitchPotentialMod(Vvdw[i], sw, rV, rvdw, zero);
                             }
                         }