More SIMD preparations in the FE calculations
authorMagnus Lundborg <lundborg.magnus@gmail.com>
Thu, 7 Nov 2019 15:47:06 +0000 (16:47 +0100)
committerMark Abraham <mark.j.abraham@gmail.com>
Wed, 29 Jan 2020 17:14:12 +0000 (18:14 +0100)
Prepare the launch of the FE kernel using SIMD. So far the kernel
is not modified to actually use SIMD.

Change-Id: Iaad24fc37549b5deaa892655be1d2a7317f65955

src/gromacs/gmxlib/nonbonded/nb_free_energy.cpp

index ecd6aa7917e0df32ad1bfaa8b3e78630897fc5fa..7785ab657a1d11b657fec45a3ceed640dcacda7a 100644 (file)
 #include "gromacs/mdtypes/forceoutput.h"
 #include "gromacs/mdtypes/forcerec.h"
 #include "gromacs/mdtypes/md_enums.h"
+#include "gromacs/simd/simd.h"
 #include "gromacs/utility/fatalerror.h"
 
 
+//! Scalar (non-SIMD) data types.
+struct ScalarDataTypes
+{
+    using RealType                     = real; //!< The data type to use as real.
+    using IntType                      = int;  //!< The data type to use as int.
+    static constexpr int simdRealWidth = 1;    //!< The width of the RealType.
+    static constexpr int simdIntWidth  = 1;    //!< The width of the IntType.
+};
+
+#if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS
+//! SIMD data types.
+struct SimdDataTypes
+{
+    using RealType                     = gmx::SimdReal;         //!< The data type to use as real.
+    using IntType                      = gmx::SimdInt32;        //!< The data type to use as int.
+    static constexpr int simdRealWidth = GMX_SIMD_REAL_WIDTH;   //!< The width of the RealType.
+    static constexpr int simdIntWidth  = GMX_SIMD_FINT32_WIDTH; //!< The width of the IntType.
+};
+#endif
+
 //! Computes r^(1/p) and 1/r^(1/p) for the standard p=6
-static inline void pthRoot(const real r, real* pthRoot, real* invPthRoot)
+template<class RealType>
+static inline void pthRoot(const RealType r, RealType* pthRoot, RealType* invPthRoot)
 {
     *invPthRoot = gmx::invsqrt(std::cbrt(r));
     *pthRoot    = 1 / (*invPthRoot);
 }
 
-static inline real calculateRinv6(const real rinvV)
+template<class RealType>
+static inline RealType calculateRinv6(const RealType rinvV)
 {
-    real rinv6 = rinvV * rinvV;
+    RealType rinv6 = rinvV * rinvV;
     return (rinv6 * rinv6 * rinv6);
 }
 
-static inline real calculateVdw6(const real c6, const real rinv6)
+template<class RealType>
+static inline RealType calculateVdw6(const RealType c6, const RealType rinv6)
 {
     return (c6 * rinv6);
 }
 
-static inline real calculateVdw12(const real c12, const real rinv6)
+template<class RealType>
+static inline RealType calculateVdw12(const RealType c12, const RealType rinv6)
 {
     return (c12 * rinv6 * rinv6);
 }
 
 /* reaction-field electrostatics */
-static inline real
-reactionFieldScalarForce(const real qq, const real rinv, const real r, const real krf, const real two)
+template<class RealType>
+static inline RealType reactionFieldScalarForce(const RealType qq,
+                                                const RealType rinv,
+                                                const RealType r,
+                                                const real     krf,
+                                                const real     two)
 {
     return (qq * (rinv - two * krf * r * r));
 }
-static inline real reactionFieldPotential(const real qq, const real rinv, const real r, const real krf, const real potentialShift)
+template<class RealType>
+static inline RealType reactionFieldPotential(const RealType qq,
+                                              const RealType rinv,
+                                              const RealType r,
+                                              const real     krf,
+                                              const real     potentialShift)
 {
     return (qq * (rinv + krf * r * r - potentialShift));
 }
 
 /* Ewald electrostatics */
-static inline real ewaldScalarForce(const real coulomb, const real rinv)
+template<class RealType>
+static inline RealType ewaldScalarForce(const RealType coulomb, const RealType rinv)
 {
     return (coulomb * rinv);
 }
-static inline real ewaldPotential(const real coulomb, const real rinv, const real potentialShift)
+template<class RealType>
+static inline RealType ewaldPotential(const RealType coulomb, const RealType rinv, const real potentialShift)
 {
     return (coulomb * (rinv - potentialShift));
 }
 
 /* cutoff LJ */
-static inline real lennardJonesScalarForce(const real v6, const real v12)
+template<class RealType>
+static inline RealType lennardJonesScalarForce(const RealType v6, const RealType v12)
 {
     return (v12 - v6);
 }
-static inline real lennardJonesPotential(const real v6,
-                                         const real v12,
-                                         const real c6,
-                                         const real c12,
-                                         const real repulsionShift,
-                                         const real dispersionShift,
-                                         const real onesixth,
-                                         const real onetwelfth)
+template<class RealType>
+static inline RealType lennardJonesPotential(const RealType v6,
+                                             const RealType v12,
+                                             const RealType c6,
+                                             const RealType c12,
+                                             const real     repulsionShift,
+                                             const real     dispersionShift,
+                                             const real     onesixth,
+                                             const real     onetwelfth)
 {
     return ((v12 + c12 * repulsionShift) * onetwelfth - (v6 + c6 * dispersionShift) * onesixth);
 }
@@ -122,13 +160,14 @@ static inline real ewaldLennardJonesGridSubtract(const real c6grid, const real p
 }
 
 /* LJ Potential switch */
-static inline real potSwitchScalarForceMod(const real fScalarInp,
-                                           const real potential,
-                                           const real sw,
-                                           const real r,
-                                           const real rVdw,
-                                           const real dsw,
-                                           const real zero)
+template<class RealType>
+static inline RealType potSwitchScalarForceMod(const RealType fScalarInp,
+                                               const RealType potential,
+                                               const RealType sw,
+                                               const RealType r,
+                                               const RealType rVdw,
+                                               const RealType dsw,
+                                               const real     zero)
 {
     if (r < rVdw)
     {
@@ -137,8 +176,12 @@ static inline real potSwitchScalarForceMod(const real fScalarInp,
     }
     return (zero);
 }
-static inline real
-potSwitchPotentialMod(const real potentialInp, const real sw, const real r, const real rVdw, const real zero)
+template<class RealType>
+static inline RealType potSwitchPotentialMod(const RealType potentialInp,
+                                             const RealType sw,
+                                             const RealType r,
+                                             const RealType rVdw,
+                                             const real     zero)
 {
     if (r < rVdw)
     {
@@ -150,7 +193,7 @@ potSwitchPotentialMod(const real potentialInp, const real sw, const real r, cons
 
 
 //! Templated free-energy non-bonded kernel
-template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
+template<typename DataTypes, bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
 static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                                   rvec* gmx_restrict         xx,
                                   gmx::ForceWithShiftForces* forceWithShiftForces,
@@ -163,6 +206,10 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
 #define STATE_B 1
 #define NSTATES 2
 
+    using RealType = typename DataTypes::RealType;
+    using IntType  = typename DataTypes::IntType;
+
+    /* FIXME: How should these be handled with SIMD? */
     constexpr real onetwelfth = 1.0 / 12.0;
     constexpr real onesixth   = 1.0 / 6.0;
     constexpr real zero       = 0.0;
@@ -339,17 +386,17 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
 
         for (int k = nj0; k < nj1; k++)
         {
-            int        tj[NSTATES];
-            const int  jnr = jjnr[k];
-            const int  j3  = 3 * jnr;
-            real       c6[NSTATES], c12[NSTATES], qq[NSTATES], Vcoul[NSTATES], Vvdw[NSTATES];
-            real       r, rinv, rp, rpm2;
-            real       alpha_vdw_eff, alpha_coul_eff, sigma6[NSTATES];
-            const real dx  = ix - x[j3];
-            const real dy  = iy - x[j3 + 1];
-            const real dz  = iz - x[j3 + 2];
-            const real rsq = dx * dx + dy * dy + dz * dz;
-            real       FscalC[NSTATES], FscalV[NSTATES];
+            int            tj[NSTATES];
+            const int      jnr = jjnr[k];
+            const int      j3  = 3 * jnr;
+            RealType       c6[NSTATES], c12[NSTATES], qq[NSTATES], Vcoul[NSTATES], Vvdw[NSTATES];
+            RealType       r, rinv, rp, rpm2;
+            RealType       alpha_vdw_eff, alpha_coul_eff, sigma6[NSTATES];
+            const RealType dx  = ix - x[j3];
+            const RealType dy  = iy - x[j3 + 1];
+            const RealType dz  = iz - x[j3 + 2];
+            const RealType rsq = dx * dx + dy * dy + dz * dz;
+            RealType       FscalC[NSTATES], FscalV[NSTATES];
 
             if (rsq >= rcutoff_max2)
             {
@@ -398,7 +445,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                 rp   = 1;
             }
 
-            real Fscal = 0;
+            RealType Fscal = 0;
 
             qq[STATE_A] = iqA * chargeA[jnr];
             qq[STATE_B] = iqB * chargeB[jnr];
@@ -454,7 +501,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                     Vcoul[i]  = 0;
                     Vvdw[i]   = 0;
 
-                    real rinvC, rinvV, rC, rV, rpinvC, rpinvV;
+                    RealType rinvC, rinvV, rC, rV, rpinvC, rpinvV;
 
                     /* Only spend time on A or B state if it is non-zero */
                     if ((qq[i] != 0) || (c6[i] != 0) || (c12[i] != 0))
@@ -518,7 +565,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                                                      || (!vdwInteractionTypeIsEwald && rV < rvdw);
                         if ((c6[i] != 0 || c12[i] != 0) && computeVdwInteraction)
                         {
-                            real rinv6;
+                            RealType rinv6;
                             if (useSoftCore)
                             {
                                 rinv6 = rpinvV;
@@ -527,8 +574,8 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                             {
                                 rinv6 = calculateRinv6(rinvV);
                             }
-                            real Vvdw6  = calculateVdw6(c6[i], rinv6);
-                            real Vvdw12 = calculateVdw12(c12[i], rinv6);
+                            RealType Vvdw6  = calculateVdw6(c6[i], rinv6);
+                            RealType Vvdw12 = calculateVdw12(c12[i], rinv6);
 
                             Vvdw[i] = lennardJonesPotential(Vvdw6, Vvdw12, c6[i], c12[i], repulsionShift,
                                                             dispersionShift, onesixth, onetwelfth);
@@ -543,11 +590,12 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
 
                             if (vdwModifierIsPotSwitch)
                             {
-                                real d        = rV - ic->rvdw_switch;
-                                d             = (d > zero) ? d : zero;
-                                const real d2 = d * d;
-                                const real sw = one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5));
-                                const real dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4));
+                                RealType d        = rV - ic->rvdw_switch;
+                                d                 = (d > zero) ? d : zero;
+                                const RealType d2 = d * d;
+                                const RealType sw =
+                                        one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5));
+                                const RealType dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4));
 
                                 FscalV[i] = potSwitchScalarForceMod(FscalV[i], Vvdw[i], sw, rV,
                                                                     rvdw, dsw, zero);
@@ -595,7 +643,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                  * As there is no singularity, there is no need for soft-core.
                  */
                 const real FF = -two * krf;
-                real       VV = krf * rsq - crf;
+                RealType   VV = krf * rsq - crf;
 
                 if (ii == jnr)
                 {
@@ -622,11 +670,11 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                  */
                 real v_lr, f_lr;
 
-                const real ewrt   = r * ewtabscale;
-                int        ewitab = static_cast<int>(ewrt);
-                const real eweps  = ewrt - ewitab;
-                ewitab            = 4 * ewitab;
-                f_lr              = ewtab[ewitab] + eweps * ewtab[ewitab + 1];
+                const RealType ewrt   = r * ewtabscale;
+                IntType        ewitab = static_cast<IntType>(ewrt);
+                const RealType eweps  = ewrt - ewitab;
+                ewitab                = 4 * ewitab;
+                f_lr                  = ewtab[ewitab] + eweps * ewtab[ewitab + 1];
                 v_lr = (ewtab[ewitab + 2] - ewtabhalfspace * eweps * (ewtab[ewitab] + f_lr));
                 f_lr *= rinv;
 
@@ -667,15 +715,16 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                  * r close to 0 for non-interacting pairs.
                  */
 
-                const real rs   = rsq * rinv * ewtabscale;
-                const int  ri   = static_cast<int>(rs);
-                const real frac = rs - ri;
-                const real f_lr = (1 - frac) * tab_ewald_F_lj[ri] + frac * tab_ewald_F_lj[ri + 1];
+                const RealType rs   = rsq * rinv * ewtabscale;
+                const IntType  ri   = static_cast<IntType>(rs);
+                const RealType frac = rs - ri;
+                const RealType f_lr = (1 - frac) * tab_ewald_F_lj[ri] + frac * tab_ewald_F_lj[ri + 1];
                 /* TODO: Currently the Ewald LJ table does not contain
                  * the factor 1/6, we should add this.
                  */
-                const real FF = f_lr * rinv / six;
-                real VV = (tab_ewald_V_lj[ri] - ewtabhalfspace * frac * (tab_ewald_F_lj[ri] + f_lr)) / six;
+                const RealType FF = f_lr * rinv / six;
+                RealType       VV =
+                        (tab_ewald_V_lj[ri] - ewtabhalfspace * frac * (tab_ewald_F_lj[ri] + f_lr)) / six;
 
                 if (ii == jnr)
                 {
@@ -777,51 +826,74 @@ typedef void (*KernelFunction)(const t_nblist* gmx_restrict nlist,
                                nb_kernel_data_t* gmx_restrict kernel_data,
                                t_nrnb* gmx_restrict nrnb);
 
+template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
+static KernelFunction dispatchKernelOnUseSimd(const bool useSimd)
+{
+    if (useSimd)
+    {
+#if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS
+        /* FIXME: Here SimdDataTypes should be used to enable SIMD. So far, the code in nb_free_energy_kernel is not adapted to SIMD */
+        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                      elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
+#else
+        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                      elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
+#endif
+    }
+    else
+    {
+        return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                      elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
+    }
+}
+
 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald>
-static KernelFunction dispatchKernelOnVdwModifier(const bool vdwModifierIsPotSwitch)
+static KernelFunction dispatchKernelOnVdwModifier(const bool vdwModifierIsPotSwitch, const bool useSimd)
 {
     if (vdwModifierIsPotSwitch)
     {
-        return (nb_free_energy_kernel<useSoftCore, scLambdasOrAlphasDiffer,
-                                      vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, true>);
+        return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                        elecInteractionTypeIsEwald, true>(useSimd));
     }
     else
     {
-        return (nb_free_energy_kernel<useSoftCore, scLambdasOrAlphasDiffer,
-                                      vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, false>);
+        return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                        elecInteractionTypeIsEwald, false>(useSimd));
     }
 }
 
 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald>
 static KernelFunction dispatchKernelOnElecInteractionType(const bool elecInteractionTypeIsEwald,
-                                                          const bool vdwModifierIsPotSwitch)
+                                                          const bool vdwModifierIsPotSwitch,
+                                                          const bool useSimd)
 {
     if (elecInteractionTypeIsEwald)
     {
         return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, true>(
-                vdwModifierIsPotSwitch));
+                vdwModifierIsPotSwitch, useSimd));
     }
     else
     {
         return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, false>(
-                vdwModifierIsPotSwitch));
+                vdwModifierIsPotSwitch, useSimd));
     }
 }
 
 template<bool useSoftCore, bool scLambdasOrAlphasDiffer>
 static KernelFunction dispatchKernelOnVdwInteractionType(const bool vdwInteractionTypeIsEwald,
                                                          const bool elecInteractionTypeIsEwald,
-                                                         const bool vdwModifierIsPotSwitch)
+                                                         const bool vdwModifierIsPotSwitch,
+                                                         const bool useSimd)
 {
     if (vdwInteractionTypeIsEwald)
     {
         return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, true>(
-                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch));
+                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
     }
     else
     {
         return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, false>(
-                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch));
+                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
     }
 }
 
@@ -829,17 +901,18 @@ template<bool useSoftCore>
 static KernelFunction dispatchKernelOnScLambdasOrAlphasDifference(const bool scLambdasOrAlphasDiffer,
                                                                   const bool vdwInteractionTypeIsEwald,
                                                                   const bool elecInteractionTypeIsEwald,
-                                                                  const bool vdwModifierIsPotSwitch)
+                                                                  const bool vdwModifierIsPotSwitch,
+                                                                  const bool useSimd)
 {
     if (scLambdasOrAlphasDiffer)
     {
         return (dispatchKernelOnVdwInteractionType<useSoftCore, true>(
-                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch));
+                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
     }
     else
     {
         return (dispatchKernelOnVdwInteractionType<useSoftCore, false>(
-                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch));
+                vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
     }
 }
 
@@ -847,19 +920,20 @@ static KernelFunction dispatchKernel(const bool        scLambdasOrAlphasDiffer,
                                      const bool        vdwInteractionTypeIsEwald,
                                      const bool        elecInteractionTypeIsEwald,
                                      const bool        vdwModifierIsPotSwitch,
+                                     const bool        useSimd,
                                      const t_forcerec* fr)
 {
     if (fr->sc_alphacoul == 0 && fr->sc_alphavdw == 0)
     {
         return (dispatchKernelOnScLambdasOrAlphasDifference<false>(
                 scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald,
-                vdwModifierIsPotSwitch));
+                vdwModifierIsPotSwitch, useSimd));
     }
     else
     {
         return (dispatchKernelOnScLambdasOrAlphasDifference<true>(
                 scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald,
-                vdwModifierIsPotSwitch));
+                vdwModifierIsPotSwitch, useSimd));
     }
 }
 
@@ -879,6 +953,7 @@ void gmx_nb_free_energy_kernel(const t_nblist*            nlist,
     const bool elecInteractionTypeIsEwald = (EEL_PME_EWALD(fr->ic->eeltype));
     const bool vdwModifierIsPotSwitch     = (fr->ic->vdw_modifier == eintmodPOTSWITCH);
     bool       scLambdasOrAlphasDiffer    = true;
+    const bool useSimd                    = fr->use_simd_kernels;
 
     if (fr->sc_alphacoul == 0 && fr->sc_alphavdw == 0)
     {
@@ -895,7 +970,9 @@ void gmx_nb_free_energy_kernel(const t_nblist*            nlist,
     {
         GMX_RELEASE_ASSERT(false, "Unsupported soft-core r-power");
     }
-    KernelFunction kernelFunc = dispatchKernel(scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
-                                               elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, fr);
+
+    KernelFunction kernelFunc;
+    kernelFunc = dispatchKernel(scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald,
+                                elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd, fr);
     kernelFunc(nlist, xx, ff, fr, mdatoms, kernel_data, nrnb);
 }