Make eelType and evdwType scoped enums, + cleanup
[alexxy/gromacs.git] / src / gromacs / nbnxm / opencl / nbnxm_ocl.cpp
index 48e8ed03ad37e1189abc522330af67ef03932779..affbad0811729cdd7ec9e0f4dd5a4ceac5d81cdf 100644 (file)
@@ -162,7 +162,7 @@ static inline void validate_global_work_size(const KernelLaunchConfig& config,
  */
 
 /*! \brief Force-only kernel function names. */
-static const char* nb_kfunc_noener_noprune_ptr[eelTypeNR][evdwTypeNR] = {
+static const char* nb_kfunc_noener_noprune_ptr[c_numElecTypes][c_numVdwTypes] = {
     { "nbnxn_kernel_ElecCut_VdwLJ_F_opencl", "nbnxn_kernel_ElecCut_VdwLJCombGeom_F_opencl",
       "nbnxn_kernel_ElecCut_VdwLJCombLB_F_opencl", "nbnxn_kernel_ElecCut_VdwLJFsw_F_opencl",
       "nbnxn_kernel_ElecCut_VdwLJPsw_F_opencl", "nbnxn_kernel_ElecCut_VdwLJEwCombGeom_F_opencl",
@@ -196,7 +196,7 @@ static const char* nb_kfunc_noener_noprune_ptr[eelTypeNR][evdwTypeNR] = {
 };
 
 /*! \brief Force + energy kernel function pointers. */
-static const char* nb_kfunc_ener_noprune_ptr[eelTypeNR][evdwTypeNR] = {
+static const char* nb_kfunc_ener_noprune_ptr[c_numElecTypes][c_numVdwTypes] = {
     { "nbnxn_kernel_ElecCut_VdwLJ_VF_opencl", "nbnxn_kernel_ElecCut_VdwLJCombGeom_VF_opencl",
       "nbnxn_kernel_ElecCut_VdwLJCombLB_VF_opencl", "nbnxn_kernel_ElecCut_VdwLJFsw_VF_opencl",
       "nbnxn_kernel_ElecCut_VdwLJPsw_VF_opencl", "nbnxn_kernel_ElecCut_VdwLJEwCombGeom_VF_opencl",
@@ -231,7 +231,7 @@ static const char* nb_kfunc_ener_noprune_ptr[eelTypeNR][evdwTypeNR] = {
 };
 
 /*! \brief Force + pruning kernel function pointers. */
-static const char* nb_kfunc_noener_prune_ptr[eelTypeNR][evdwTypeNR] = {
+static const char* nb_kfunc_noener_prune_ptr[c_numElecTypes][c_numVdwTypes] = {
     { "nbnxn_kernel_ElecCut_VdwLJ_F_prune_opencl",
       "nbnxn_kernel_ElecCut_VdwLJCombGeom_F_prune_opencl",
       "nbnxn_kernel_ElecCut_VdwLJCombLB_F_prune_opencl",
@@ -272,7 +272,7 @@ static const char* nb_kfunc_noener_prune_ptr[eelTypeNR][evdwTypeNR] = {
 };
 
 /*! \brief Force + energy + pruning kernel function pointers. */
-static const char* nb_kfunc_ener_prune_ptr[eelTypeNR][evdwTypeNR] = {
+static const char* nb_kfunc_ener_prune_ptr[c_numElecTypes][c_numVdwTypes] = {
     { "nbnxn_kernel_ElecCut_VdwLJ_VF_prune_opencl",
       "nbnxn_kernel_ElecCut_VdwLJCombGeom_VF_prune_opencl",
       "nbnxn_kernel_ElecCut_VdwLJCombLB_VF_prune_opencl",
@@ -341,41 +341,45 @@ static inline cl_kernel selectPruneKernel(cl_kernel kernel_pruneonly[], bool fir
  *  OpenCL kernel objects are cached in nb. If the requested kernel is not
  *  found in the cache, it will be created and the cache will be updated.
  */
-static inline cl_kernel select_nbnxn_kernel(NbnxmGpu* nb, int eeltype, int evdwtype, bool bDoEne, bool bDoPrune)
+static inline cl_kernel
+select_nbnxn_kernel(NbnxmGpu* nb, enum ElecType elecType, enum VdwType vdwType, bool bDoEne, bool bDoPrune)
 {
     const char* kernel_name_to_run;
     cl_kernel*  kernel_ptr;
     cl_int      cl_error;
 
-    GMX_ASSERT(eeltype < eelTypeNR,
+    const int elecTypeIdx = static_cast<int>(elecType);
+    const int vdwTypeIdx  = static_cast<int>(vdwType);
+
+    GMX_ASSERT(elecTypeIdx < c_numElecTypes,
                "The electrostatics type requested is not implemented in the OpenCL kernels.");
-    GMX_ASSERT(evdwtype < evdwTypeNR,
+    GMX_ASSERT(vdwTypeIdx < c_numVdwTypes,
                "The VdW type requested is not implemented in the OpenCL kernels.");
 
     if (bDoEne)
     {
         if (bDoPrune)
         {
-            kernel_name_to_run = nb_kfunc_ener_prune_ptr[eeltype][evdwtype];
-            kernel_ptr         = &(nb->kernel_ener_prune_ptr[eeltype][evdwtype]);
+            kernel_name_to_run = nb_kfunc_ener_prune_ptr[elecTypeIdx][vdwTypeIdx];
+            kernel_ptr         = &(nb->kernel_ener_prune_ptr[elecTypeIdx][vdwTypeIdx]);
         }
         else
         {
-            kernel_name_to_run = nb_kfunc_ener_noprune_ptr[eeltype][evdwtype];
-            kernel_ptr         = &(nb->kernel_ener_noprune_ptr[eeltype][evdwtype]);
+            kernel_name_to_run = nb_kfunc_ener_noprune_ptr[elecTypeIdx][vdwTypeIdx];
+            kernel_ptr         = &(nb->kernel_ener_noprune_ptr[elecTypeIdx][vdwTypeIdx]);
         }
     }
     else
     {
         if (bDoPrune)
         {
-            kernel_name_to_run = nb_kfunc_noener_prune_ptr[eeltype][evdwtype];
-            kernel_ptr         = &(nb->kernel_noener_prune_ptr[eeltype][evdwtype]);
+            kernel_name_to_run = nb_kfunc_noener_prune_ptr[elecTypeIdx][vdwTypeIdx];
+            kernel_ptr         = &(nb->kernel_noener_prune_ptr[elecTypeIdx][vdwTypeIdx]);
         }
         else
         {
-            kernel_name_to_run = nb_kfunc_noener_noprune_ptr[eeltype][evdwtype];
-            kernel_ptr         = &(nb->kernel_noener_noprune_ptr[eeltype][evdwtype]);
+            kernel_name_to_run = nb_kfunc_noener_noprune_ptr[elecTypeIdx][vdwTypeIdx];
+            kernel_ptr         = &(nb->kernel_noener_noprune_ptr[elecTypeIdx][vdwTypeIdx]);
         }
     }
 
@@ -392,7 +396,7 @@ static inline cl_kernel select_nbnxn_kernel(NbnxmGpu* nb, int eeltype, int evdwt
 
 /*! \brief Calculates the amount of shared memory required by the nonbonded kernel in use.
  */
-static inline int calc_shmem_required_nonbonded(int vdwType, bool bPrefetchLjParam)
+static inline int calc_shmem_required_nonbonded(enum VdwType vdwType, bool bPrefetchLjParam)
 {
     int shmem;
 
@@ -438,7 +442,7 @@ static void fillin_ocl_structures(NBParamGpu* nbp, cl_nbparam_params_t* nbparams
     nbparams_params->coulomb_tab_scale = nbp->coulomb_tab_scale;
     nbparams_params->c_rf              = nbp->c_rf;
     nbparams_params->dispersion_shift  = nbp->dispersion_shift;
-    nbparams_params->eeltype           = nbp->eeltype;
+    nbparams_params->elecType          = nbp->elecType;
     nbparams_params->epsfac            = nbp->epsfac;
     nbparams_params->ewaldcoeff_lj     = nbp->ewaldcoeff_lj;
     nbparams_params->ewald_beta        = nbp->ewald_beta;
@@ -451,7 +455,7 @@ static void fillin_ocl_structures(NBParamGpu* nbp, cl_nbparam_params_t* nbparams
     nbparams_params->sh_ewald          = nbp->sh_ewald;
     nbparams_params->sh_lj_ewald       = nbp->sh_lj_ewald;
     nbparams_params->two_k_rf          = nbp->two_k_rf;
-    nbparams_params->vdwtype           = nbp->vdwtype;
+    nbparams_params->vdwType           = nbp->vdwType;
     nbparams_params->vdw_switch        = nbp->vdw_switch;
 }
 
@@ -637,7 +641,7 @@ void gpu_launch_kernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const Nb
     /* kernel launch config */
 
     KernelLaunchConfig config;
-    config.sharedMemorySize = calc_shmem_required_nonbonded(nbp->vdwtype, nb->bPrefetchLjParam);
+    config.sharedMemorySize = calc_shmem_required_nonbonded(nbp->vdwType, nb->bPrefetchLjParam);
     config.blockSize[0]     = c_clSize;
     config.blockSize[1]     = c_clSize;
     config.gridSize[0]      = plist->nsci;
@@ -660,14 +664,14 @@ void gpu_launch_kernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const Nb
     auto*          timingEvent  = bDoTime ? t->interaction[iloc].nb_k.fetchNextEvent() : nullptr;
     constexpr char kernelName[] = "k_calc_nb";
     const auto     kernel =
-            select_nbnxn_kernel(nb, nbp->eeltype, nbp->vdwtype, stepWork.computeEnergy,
+            select_nbnxn_kernel(nb, nbp->elecType, nbp->vdwType, stepWork.computeEnergy,
                                 (plist->haveFreshList && !nb->timers->interaction[iloc].didPrune));
 
 
     // The OpenCL kernel takes int as second to last argument because bool is
     // not supported as a kernel argument type (sizeof(bool) is implementation defined).
     const int computeFshift = static_cast<int>(stepWork.computeVirial);
-    if (useLjCombRule(nb->nbparam->vdwtype))
+    if (useLjCombRule(nb->nbparam->vdwType))
     {
         const auto kernelArgs = prepareGpuKernelArguments(
                 kernel, config, &nbparams_params, &adat->xq, &adat->f, &adat->e_lj, &adat->e_el,