Convert nbnxn_atomdata_t to C++
[alexxy/gromacs.git] / src / gromacs / mdlib / nbnxn_ocl / nbnxn_ocl_data_mgmt.cpp
index 9547252b74b748b0939854abecaa5130fa51147f..4e1b973d672f9b71778f5a5906d5e9e8bfeab9a1 100644 (file)
@@ -145,7 +145,7 @@ static void init_atomdata_first(cl_atomdata_t *ad, int ntypes, gmx_device_runtim
 
     /* An element of the shift_vec device buffer has the same size as one element
        of the host side shift_vec buffer. */
-    ad->shift_vec_elem_size = sizeof(*nbnxn_atomdata_t::shift_vec);
+    ad->shift_vec_elem_size = sizeof(*nbnxn_atomdata_t::shift_vec.data());
 
     ad->shift_vec = clCreateBuffer(runData->context, CL_MEM_READ_ONLY | CL_MEM_HOST_WRITE_ONLY,
                                    SHIFTS * ad->shift_vec_elem_size, nullptr, &cl_error);
@@ -291,7 +291,7 @@ map_interaction_types_to_gpu_kernel_flavors(const interaction_const_t *ic,
 static void init_nbparam(cl_nbparam_t                    *nbp,
                          const interaction_const_t       *ic,
                          const NbnxnListParameters       *listParams,
-                         const nbnxn_atomdata_t          *nbat,
+                         const nbnxn_atomdata_t::Params  &nbatParams,
                          const gmx_device_runtime_data_t *runData)
 {
     cl_int cl_error;
@@ -299,7 +299,7 @@ static void init_nbparam(cl_nbparam_t                    *nbp,
     set_cutoff_parameters(nbp, ic, listParams);
 
     map_interaction_types_to_gpu_kernel_flavors(ic,
-                                                nbat->comb_rule,
+                                                nbatParams.comb_rule,
                                                 &(nbp->eeltype),
                                                 &(nbp->vdwtype));
 
@@ -307,11 +307,11 @@ static void init_nbparam(cl_nbparam_t                    *nbp,
     {
         if (ic->ljpme_comb_rule == ljcrGEOM)
         {
-            GMX_ASSERT(nbat->comb_rule == ljcrGEOM, "Combination rule mismatch!");
+            GMX_ASSERT(nbatParams.comb_rule == ljcrGEOM, "Combination rule mismatch!");
         }
         else
         {
-            GMX_ASSERT(nbat->comb_rule == ljcrLB, "Combination rule mismatch!");
+            GMX_ASSERT(nbatParams.comb_rule == ljcrLB, "Combination rule mismatch!");
         }
     }
     /* generate table for PME */
@@ -342,8 +342,8 @@ static void init_nbparam(cl_nbparam_t                    *nbp,
                            ("clCreateBuffer failed: " + ocl_get_error_string(cl_error)).c_str());
     }
 
-    int nnbfp      = 2*nbat->ntype*nbat->ntype;
-    int nnbfp_comb = 2*nbat->ntype;
+    const int nnbfp      = 2*nbatParams.numTypes*nbatParams.numTypes;
+    const int nnbfp_comb = 2*nbatParams.numTypes;
 
     {
         /* Switched from using textures to using buffers */
@@ -358,8 +358,12 @@ static void init_nbparam(cl_nbparam_t                    *nbp,
             &array_format, nnbfp, 1, 0, nbat->nbfp, &cl_error);
          */
 
-        nbp->nbfp_climg2d = clCreateBuffer(runData->context, CL_MEM_READ_ONLY | CL_MEM_HOST_WRITE_ONLY | CL_MEM_COPY_HOST_PTR,
-                                           nnbfp*sizeof(cl_float), nbat->nbfp, &cl_error);
+        nbp->nbfp_climg2d =
+            clCreateBuffer(runData->context,
+                           CL_MEM_READ_ONLY | CL_MEM_HOST_WRITE_ONLY | CL_MEM_COPY_HOST_PTR,
+                           nnbfp*sizeof(cl_float),
+                           const_cast<float *>(nbatParams.nbfp.data()),
+                           &cl_error);
         GMX_RELEASE_ASSERT(cl_error == CL_SUCCESS,
                            ("clCreateBuffer failed: " + ocl_get_error_string(cl_error)).c_str());
 
@@ -369,8 +373,12 @@ static void init_nbparam(cl_nbparam_t                    *nbp,
             // TODO: decide which alternative is most efficient - textures or buffers.
             /*  nbp->nbfp_comb_climg2d = clCreateImage2D(runData->context, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
                 &array_format, nnbfp_comb, 1, 0, nbat->nbfp_comb, &cl_error);*/
-            nbp->nbfp_comb_climg2d = clCreateBuffer(runData->context, CL_MEM_READ_ONLY | CL_MEM_HOST_WRITE_ONLY | CL_MEM_COPY_HOST_PTR,
-                                                    nnbfp_comb*sizeof(cl_float), nbat->nbfp_comb, &cl_error);
+            nbp->nbfp_comb_climg2d =
+                clCreateBuffer(runData->context,
+                               CL_MEM_READ_ONLY | CL_MEM_HOST_WRITE_ONLY | CL_MEM_COPY_HOST_PTR,
+                               nnbfp_comb*sizeof(cl_float),
+                               const_cast<float *>(nbatParams.nbfp_comb.data()),
+                               &cl_error);
             GMX_RELEASE_ASSERT(cl_error == CL_SUCCESS,
                                ("clCreateBuffer failed: " + ocl_get_error_string(cl_error)).c_str());
         }
@@ -598,10 +606,10 @@ static void nbnxn_gpu_init_kernels(gmx_nbnxn_ocl_t *nb)
 static void nbnxn_ocl_init_const(gmx_nbnxn_ocl_t                *nb,
                                  const interaction_const_t      *ic,
                                  const NbnxnListParameters      *listParams,
-                                 const nbnxn_atomdata_t         *nbat)
+                                 const nbnxn_atomdata_t::Params &nbatParams)
 {
-    init_atomdata_first(nb->atdat, nbat->ntype, nb->dev_rundata);
-    init_nbparam(nb->nbparam, ic, listParams, nbat, nb->dev_rundata);
+    init_atomdata_first(nb->atdat, nbatParams.numTypes, nb->dev_rundata);
+    init_nbparam(nb->nbparam, ic, listParams, nbatParams, nb->dev_rundata);
 }
 
 
@@ -695,7 +703,7 @@ void nbnxn_gpu_init(gmx_nbnxn_ocl_t          **p_nb,
         init_timings(nb->timings);
     }
 
-    nbnxn_ocl_init_const(nb, ic, listParams, nbat);
+    nbnxn_ocl_init_const(nb, ic, listParams, nbat->params());
 
     /* Enable LJ param manual prefetch for AMD or Intel or if we request through env. var.
      * TODO: decide about NVIDIA
@@ -839,7 +847,7 @@ void nbnxn_gpu_upload_shiftvec(gmx_nbnxn_ocl_t        *nb,
     /* only if we have a dynamic box */
     if (nbatom->bDynamicBox || !adat->bShiftVecUploaded)
     {
-        ocl_copy_H2D_async(adat->shift_vec, nbatom->shift_vec, 0,
+        ocl_copy_H2D_async(adat->shift_vec, nbatom->shift_vec.data(), 0,
                            SHIFTS * adat->shift_vec_elem_size, ls, nullptr);
         adat->bShiftVecUploaded = CL_TRUE;
     }
@@ -857,7 +865,7 @@ void nbnxn_gpu_init_atomdata(gmx_nbnxn_ocl_t               *nb,
     cl_atomdata_t   *d_atdat = nb->atdat;
     cl_command_queue ls      = nb->stream[eintLocal];
 
-    natoms    = nbat->natoms;
+    natoms    = nbat->numAtoms();
     realloced = false;
 
     if (bDoTime)
@@ -923,12 +931,12 @@ void nbnxn_gpu_init_atomdata(gmx_nbnxn_ocl_t               *nb,
 
     if (useLjCombRule(nb->nbparam->vdwtype))
     {
-        ocl_copy_H2D_async(d_atdat->lj_comb, nbat->lj_comb, 0,
+        ocl_copy_H2D_async(d_atdat->lj_comb, nbat->params().lj_comb.data(), 0,
                            natoms*sizeof(cl_float2), ls, bDoTime ? timers->atdat.fetchNextEvent() : nullptr);
     }
     else
     {
-        ocl_copy_H2D_async(d_atdat->atom_types, nbat->type, 0,
+        ocl_copy_H2D_async(d_atdat->atom_types, nbat->params().type.data(), 0,
                            natoms*sizeof(int), ls, bDoTime ? timers->atdat.fetchNextEvent() : nullptr);
 
     }