Simplify LJ parameter lookup
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_kernel.cpp
index a372afa77542b646d53bf6a2e413db20558a7f3b..504557d5630c8e14da7bed76015157edae476bb4 100644 (file)
@@ -100,15 +100,14 @@ using cl::sycl::access::fence_space;
 using cl::sycl::access::mode;
 using cl::sycl::access::target;
 
-static inline void convertSigmaEpsilonToC6C12(const float                  sigma,
-                                              const float                  epsilon,
-                                              cl::sycl::private_ptr<float> c6,
-                                              cl::sycl::private_ptr<float> c12)
+static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
 {
     const float sigma2 = sigma * sigma;
     const float sigma6 = sigma2 * sigma2 * sigma2;
-    *c6                = epsilon * sigma6;
-    *c12               = (*c6) * sigma6;
+    const float c6     = epsilon * sigma6;
+    const float c12    = c6 * sigma6;
+
+    return Float2(c6, c12);
 }
 
 template<bool doCalcEnergies>
@@ -147,25 +146,23 @@ static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
 
 //! \brief Fetch C6 grid contribution coefficients and return the product of these.
 template<enum VdwType vdwType>
-static inline float calculateLJEwaldC6Grid(const DeviceAccessor<float, mode::read> a_nbfpComb,
-                                           const int                               typeI,
-                                           const int                               typeJ)
+static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
+                                           const int                                typeI,
+                                           const int                                typeJ)
 {
     if constexpr (vdwType == VdwType::EwaldGeom)
     {
-        return a_nbfpComb[2 * typeI] * a_nbfpComb[2 * typeJ];
+        return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
     }
     else
     {
         static_assert(vdwType == VdwType::EwaldLB);
         /* sigma and epsilon are scaled to give 6*C6 */
-        const float c6_i  = a_nbfpComb[2 * typeI];
-        const float c12_i = a_nbfpComb[2 * typeI + 1];
-        const float c6_j  = a_nbfpComb[2 * typeJ];
-        const float c12_j = a_nbfpComb[2 * typeJ + 1];
+        const Float2 c6c12_i = a_nbfpComb[typeI];
+        const Float2 c6c12_j = a_nbfpComb[typeJ];
 
-        const float sigma   = c6_i + c6_j;
-        const float epsilon = c12_i * c12_j;
+        const float sigma   = c6c12_i[0] + c6c12_j[0];
+        const float epsilon = c6c12_i[1] * c6c12_j[1];
 
         const float sigma2 = sigma * sigma;
         return epsilon * sigma2 * sigma2 * sigma2;
@@ -174,17 +171,17 @@ static inline float calculateLJEwaldC6Grid(const DeviceAccessor<float, mode::rea
 
 //! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
 template<bool doCalcEnergies, enum VdwType vdwType>
-static inline void ljEwaldComb(const DeviceAccessor<float, mode::read> a_nbfpComb,
-                               const float                             sh_lj_ewald,
-                               const int                               typeI,
-                               const int                               typeJ,
-                               const float                             r2,
-                               const float                             r2Inv,
-                               const float                             lje_coeff2,
-                               const float                             lje_coeff6_6,
-                               const float                             int_bit,
-                               cl::sycl::private_ptr<float>            fInvR,
-                               cl::sycl::private_ptr<float>            eLJ)
+static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
+                               const float                              sh_lj_ewald,
+                               const int                                typeI,
+                               const int                                typeJ,
+                               const float                              r2,
+                               const float                              r2Inv,
+                               const float                              lje_coeff2,
+                               const float                              lje_coeff6_6,
+                               const float                              int_bit,
+                               cl::sycl::private_ptr<float>             fInvR,
+                               cl::sycl::private_ptr<float>             eLJ)
 {
     const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);
 
@@ -429,8 +426,8 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
                  DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
                  OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
                  OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
-                 OptionalAccessor<float, mode::read, !ljComb<vdwType>>       a_nbfp,
-                 OptionalAccessor<float, mode::read, ljEwald<vdwType>>       a_nbfpComb,
+                 OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
+                 OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
                  OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
                  const int                                                   numTypes,
                  const float                                                 rCoulombSq,
@@ -609,7 +606,7 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
                     {
                         energyVdw +=
                                 a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
-                                       * (numTypes + 1) * 2];
+                                       * (numTypes + 1)][0];
                     }
                 }
                 /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
@@ -712,23 +709,21 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
                         {
                             const float qi = xqi[3];
                             int         atomTypeI; // Only needed if (!props.vdwComb)
-                            float       c6, c12, sigma, epsilon;
+                            float       sigma, epsilon;
+                            Float2      c6c12;
 
                             if constexpr (!props.vdwComb)
                             {
                                 /* LJ 6*C6 and 12*C12 */
-                                atomTypeI     = sm_atomTypeI[i][tidxi];
-                                const int idx = (numTypes * atomTypeI + atomTypeJ) * 2;
-                                c6            = a_nbfp[idx]; // TODO: Make a_nbfm into float2
-                                c12           = a_nbfp[idx + 1];
+                                atomTypeI = sm_atomTypeI[i][tidxi];
+                                c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
                             }
                             else
                             {
                                 const Float2 ljCombI = sm_ljCombI[i][tidxi];
                                 if constexpr (props.vdwCombGeom)
                                 {
-                                    c6  = ljCombI[0] * ljCombJ[0];
-                                    c12 = ljCombI[1] * ljCombJ[1];
+                                    c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
                                 }
                                 else
                                 {
@@ -738,11 +733,15 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
                                     epsilon = ljCombI[1] * ljCombJ[1];
                                     if constexpr (doCalcEnergies)
                                     {
-                                        convertSigmaEpsilonToC6C12(sigma, epsilon, &c6, &c12);
+                                        c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
                                     }
                                 } // props.vdwCombGeom
                             }     // !props.vdwComb
 
+                            // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
+                            const float c6  = c6c12[0];
+                            const float c12 = c6c12[1];
+
                             // Ensure distance do not become so small that r^-12 overflows
                             r2 = std::max(r2, c_nbnxnMinDistanceSquared);
 #if GMX_SYCL_HIPSYCL