Merge branch release-2021
[alexxy/gromacs.git] / src / gromacs / nbnxm / opencl / nbnxm_ocl_data_mgmt.cpp
index ac99cf926d7a94b25427712ac54888df86e4bcda..290d5f669a675b8e9b5ead2c4851f2bd2412879c 100644 (file)
@@ -124,78 +124,6 @@ static void init_atomdata_first(cl_atomdata_t* ad, int ntypes, const DeviceConte
     ad->nalloc = -1;
 }
 
-/*! \brief Returns the kinds of electrostatics and Vdw OpenCL
- *  kernels that will be used.
- *
- * Respectively, these values are from enum eelOcl and enum
- * evdwOcl. */
-static void map_interaction_types_to_gpu_kernel_flavors(const interaction_const_t* ic,
-                                                        int                        combRule,
-                                                        int*                       gpu_eeltype,
-                                                        int*                       gpu_vdwtype,
-                                                        const DeviceContext&       deviceContext)
-{
-    if (ic->vdwtype == evdwCUT)
-    {
-        switch (ic->vdw_modifier)
-        {
-            case eintmodNONE:
-            case eintmodPOTSHIFT:
-                switch (combRule)
-                {
-                    case ljcrNONE: *gpu_vdwtype = evdwTypeCUT; break;
-                    case ljcrGEOM: *gpu_vdwtype = evdwTypeCUTCOMBGEOM; break;
-                    case ljcrLB: *gpu_vdwtype = evdwTypeCUTCOMBLB; break;
-                    default:
-                        gmx_incons(
-                                "The requested LJ combination rule is not implemented in the "
-                                "OpenCL GPU accelerated kernels!");
-                }
-                break;
-            case eintmodFORCESWITCH: *gpu_vdwtype = evdwTypeFSWITCH; break;
-            case eintmodPOTSWITCH: *gpu_vdwtype = evdwTypePSWITCH; break;
-            default:
-                gmx_incons(
-                        "The requested VdW interaction modifier is not implemented in the GPU "
-                        "accelerated kernels!");
-        }
-    }
-    else if (ic->vdwtype == evdwPME)
-    {
-        if (ic->ljpme_comb_rule == ljcrGEOM)
-        {
-            *gpu_vdwtype = evdwTypeEWALDGEOM;
-        }
-        else
-        {
-            *gpu_vdwtype = evdwTypeEWALDLB;
-        }
-    }
-    else
-    {
-        gmx_incons("The requested VdW type is not implemented in the GPU accelerated kernels!");
-    }
-
-    if (ic->eeltype == eelCUT)
-    {
-        *gpu_eeltype = eelTypeCUT;
-    }
-    else if (EEL_RF(ic->eeltype))
-    {
-        *gpu_eeltype = eelTypeRF;
-    }
-    else if ((EEL_PME(ic->eeltype) || ic->eeltype == eelEWALD))
-    {
-        *gpu_eeltype = nbnxn_gpu_pick_ewald_kernel_type(*ic, deviceContext.deviceInfo());
-    }
-    else
-    {
-        /* Shouldn't happen, as this is checked when choosing Verlet-scheme */
-        gmx_incons(
-                "The requested electrostatics type is not implemented in the GPU accelerated "
-                "kernels!");
-    }
-}
 
 /*! \brief Initializes the nonbonded parameter data structure.
  */
@@ -207,8 +135,8 @@ static void init_nbparam(NBParamGpu*                     nbp,
 {
     set_cutoff_parameters(nbp, ic, listParams);
 
-    map_interaction_types_to_gpu_kernel_flavors(ic, nbatParams.comb_rule, &(nbp->eeltype),
-                                                &(nbp->vdwtype), deviceContext);
+    nbp->vdwType  = nbnxmGpuPickVdwKernelType(ic, nbatParams.comb_rule);
+    nbp->elecType = nbnxmGpuPickElectrostaticsKernelType(ic, deviceContext.deviceInfo());
 
     if (ic->vdwtype == evdwPME)
     {
@@ -223,7 +151,7 @@ static void init_nbparam(NBParamGpu*                     nbp,
     }
     /* generate table for PME */
     nbp->coulomb_tab = nullptr;
-    if (nbp->eeltype == eelTypeEWALD_TAB || nbp->eeltype == eelTypeEWALD_TAB_TWIN)
+    if (nbp->elecType == ElecType::EwaldTab || nbp->elecType == ElecType::EwaldTabTwin)
     {
         GMX_RELEASE_ASSERT(ic->coulombEwaldTables, "Need valid Coulomb Ewald correction tables");
         init_ewald_coulomb_force_table(*ic->coulombEwaldTables, nbp, deviceContext);
@@ -260,8 +188,11 @@ static cl_kernel nbnxn_gpu_create_kernel(NbnxmGpu* nb, const char* kernel_name)
     kernel = clCreateKernel(nb->dev_rundata->program, kernel_name, &cl_error);
     if (CL_SUCCESS != cl_error)
     {
-        gmx_fatal(FARGS, "Failed to create kernel '%s' for GPU #%s: OpenCL error %d", kernel_name,
-                  nb->deviceContext_->deviceInfo().device_name, cl_error);
+        gmx_fatal(FARGS,
+                  "Failed to create kernel '%s' for GPU #%s: OpenCL error %d",
+                  kernel_name,
+                  nb->deviceContext_->deviceInfo().device_name,
+                  cl_error);
     }
 
     return kernel;
@@ -296,8 +227,8 @@ static void nbnxn_ocl_clear_e_fshift(NbnxmGpu* nb)
     cl_error |= clSetKernelArg(zero_e_fshift, arg_no++, sizeof(cl_uint), &shifts);
     GMX_ASSERT(cl_error == CL_SUCCESS, ocl_get_error_string(cl_error).c_str());
 
-    cl_error = clEnqueueNDRangeKernel(ls, zero_e_fshift, 3, nullptr, global_work_size,
-                                      local_work_size, 0, nullptr, nullptr);
+    cl_error = clEnqueueNDRangeKernel(
+            ls, zero_e_fshift, 3, nullptr, global_work_size, local_work_size, 0, nullptr, nullptr);
     GMX_ASSERT(cl_error == CL_SUCCESS, ocl_get_error_string(cl_error).c_str());
 }
 
@@ -473,8 +404,13 @@ void gpu_upload_shiftvec(NbnxmGpu* nb, const nbnxn_atomdata_t* nbatom)
     {
         GMX_ASSERT(sizeof(float) * DIM == sizeof(*nbatom->shift_vec.data()),
                    "Sizes of host- and device-side shift vectors should be the same.");
-        copyToDeviceBuffer(&adat->shift_vec, reinterpret_cast<const float*>(nbatom->shift_vec.data()),
-                           0, SHIFTS * DIM, deviceStream, GpuApiCallBehavior::Async, nullptr);
+        copyToDeviceBuffer(&adat->shift_vec,
+                           reinterpret_cast<const float*>(nbatom->shift_vec.data()),
+                           0,
+                           SHIFTS * DIM,
+                           deviceStream,
+                           GpuApiCallBehavior::Async,
+                           nullptr);
         adat->bShiftVecUploaded = CL_TRUE;
     }
 }
@@ -519,7 +455,7 @@ void gpu_init_atomdata(NbnxmGpu* nb, const nbnxn_atomdata_t* nbat)
         allocateDeviceBuffer(&d_atdat->f, nalloc * DIM, deviceContext);
         allocateDeviceBuffer(&d_atdat->xq, nalloc * (DIM + 1), deviceContext);
 
-        if (useLjCombRule(nb->nbparam->vdwtype))
+        if (useLjCombRule(nb->nbparam->vdwType))
         {
             // Two Lennard-Jones parameters per atom
             allocateDeviceBuffer(&d_atdat->lj_comb, nalloc * 2, deviceContext);
@@ -542,20 +478,29 @@ void gpu_init_atomdata(NbnxmGpu* nb, const nbnxn_atomdata_t* nbat)
         nbnxn_ocl_clear_f(nb, nalloc);
     }
 
-    if (useLjCombRule(nb->nbparam->vdwtype))
+    if (useLjCombRule(nb->nbparam->vdwType))
     {
         GMX_ASSERT(sizeof(float) == sizeof(*nbat->params().lj_comb.data()),
                    "Size of the LJ parameters element should be equal to the size of float2.");
-        copyToDeviceBuffer(&d_atdat->lj_comb, nbat->params().lj_comb.data(), 0, 2 * natoms,
-                           deviceStream, GpuApiCallBehavior::Async,
+        copyToDeviceBuffer(&d_atdat->lj_comb,
+                           nbat->params().lj_comb.data(),
+                           0,
+                           2 * natoms,
+                           deviceStream,
+                           GpuApiCallBehavior::Async,
                            bDoTime ? timers->atdat.fetchNextEvent() : nullptr);
     }
     else
     {
         GMX_ASSERT(sizeof(int) == sizeof(*nbat->params().type.data()),
                    "Sizes of host- and device-side atom types should be the same.");
-        copyToDeviceBuffer(&d_atdat->atom_types, nbat->params().type.data(), 0, natoms, deviceStream,
-                           GpuApiCallBehavior::Async, bDoTime ? timers->atdat.fetchNextEvent() : nullptr);
+        copyToDeviceBuffer(&d_atdat->atom_types,
+                           nbat->params().type.data(),
+                           0,
+                           natoms,
+                           deviceStream,
+                           GpuApiCallBehavior::Async,
+                           bDoTime ? timers->atdat.fetchNextEvent() : nullptr);
     }
 
     if (bDoTime)