Simplify LJ parameter lookup
[alexxy/gromacs.git] / src / gromacs / nbnxm / opencl / nbnxm_ocl_data_mgmt.cpp
index a2ee20416f4684cfd6468e906a6380466c853a7b..0583d0a73d7d6896aa7ebacf2febbfeaea4b9717 100644 (file)
@@ -163,19 +163,28 @@ static void init_nbparam(NBParamGpu*                     nbp,
         allocateDeviceBuffer(&nbp->coulomb_tab, 1, deviceContext);
     }
 
-    const int nnbfp      = 2 * nbatParams.numTypes * nbatParams.numTypes;
-    const int nnbfp_comb = 2 * nbatParams.numTypes;
-
     {
         /* set up LJ parameter lookup table */
-        DeviceBuffer<real> nbfp;
-        initParamLookupTable(&nbfp, nullptr, nbatParams.nbfp.data(), nnbfp, deviceContext);
+        static_assert(sizeof(Float2) == 2 * sizeof(decltype(*nbatParams.nbfp.data())),
+                      "Mismatch in the size of host / device data types");
+        DeviceBuffer<Float2> nbfp;
+        initParamLookupTable(&nbfp,
+                             nullptr,
+                             reinterpret_cast<const Float2*>(nbatParams.nbfp.data()),
+                             nbatParams.numTypes * nbatParams.numTypes,
+                             deviceContext);
         nbp->nbfp = nbfp;
 
         if (ic->vdwtype == VanDerWaalsType::Pme)
         {
-            DeviceBuffer<float> nbfp_comb;
-            initParamLookupTable(&nbfp_comb, nullptr, nbatParams.nbfp_comb.data(), nnbfp_comb, deviceContext);
+            static_assert(sizeof(Float2) == 2 * sizeof(decltype(*nbatParams.nbfp_comb.data())),
+                          "Mismatch in the size of host / device data types");
+            DeviceBuffer<Float2> nbfp_comb;
+            initParamLookupTable(&nbfp_comb,
+                                 nullptr,
+                                 reinterpret_cast<const Float2*>(nbatParams.nbfp_comb.data()),
+                                 nbatParams.numTypes,
+                                 deviceContext);
             nbp->nbfp_comb = nbfp_comb;
         }
     }
@@ -482,8 +491,9 @@ void gpu_init_atomdata(NbnxmGpu* nb, const nbnxn_atomdata_t* nbat)
 
     if (useLjCombRule(nb->nbparam->vdwType))
     {
-        static_assert(sizeof(float) == sizeof(*nbat->params().lj_comb.data()),
-                      "Size of the LJ parameters element should be equal to the size of float2.");
+        static_assert(
+                sizeof(Float2) == 2 * sizeof(*nbat->params().lj_comb.data()),
+                "Size of a pair of LJ parameters elements should be equal to the size of Float2.");
         copyToDeviceBuffer(&d_atdat->ljComb,
                            reinterpret_cast<const Float2*>(nbat->params().lj_comb.data()),
                            0,