Remove nbnxm kernel type from pairlist generation
[alexxy/gromacs.git] / src / gromacs / nbnxm / pairlist.cpp
index 9b62b763c9174c628667878010c1cc0c8eba9e35..6e0fde9908bccc65139923a742f0a4b8e5da2fec 100644 (file)
@@ -55,7 +55,6 @@
 #include "gromacs/mdtypes/md_enums.h"
 #include "gromacs/nbnxm/atomdata.h"
 #include "gromacs/nbnxm/gpu_data_mgmt.h"
-#include "gromacs/nbnxm/nbnxm.h"
 #include "gromacs/nbnxm/nbnxm_geometry.h"
 #include "gromacs/nbnxm/nbnxm_simd.h"
 #include "gromacs/pbcutil/ishift.h"
@@ -68,6 +67,7 @@
 #include "gromacs/utility/gmxomp.h"
 #include "gromacs/utility/smalloc.h"
 
+#include "clusterdistancekerneltype.h"
 #include "gridset.h"
 #include "pairlistset.h"
 #include "pairlistsets.h"
@@ -2346,24 +2346,24 @@ static void icell_set_x_simple(int ci,
 static void icell_set_x(int ci,
                         real shx, real shy, real shz,
                         int stride, const real *x,
-                        const Nbnxm::KernelType kernelType,
+                        const ClusterDistanceKernelType kernelType,
                         NbnxnPairlistCpuWork *work)
 {
     switch (kernelType)
     {
 #if GMX_SIMD
 #ifdef GMX_NBNXN_SIMD_4XN
-        case Nbnxm::KernelType::Cpu4xN_Simd_4xN:
+        case ClusterDistanceKernelType::CpuSimd_4xM:
             icell_set_x_simd_4xn(ci, shx, shy, shz, stride, x, work);
             break;
 #endif
 #ifdef GMX_NBNXN_SIMD_2XNN
-        case Nbnxm::KernelType::Cpu4xN_Simd_2xNN:
+        case ClusterDistanceKernelType::CpuSimd_2xMM:
             icell_set_x_simd_2xnn(ci, shx, shy, shz, stride, x, work);
             break;
 #endif
 #endif
-        case Nbnxm::KernelType::Cpu4x4_PlainC:
+        case ClusterDistanceKernelType::CpuPlainC:
             icell_set_x_simple(ci, shx, shy, shz, stride, x, &work->iClusterData);
             break;
         default:
@@ -2376,7 +2376,7 @@ static void icell_set_x(int ci,
 static void icell_set_x(int ci,
                         real shx, real shy, real shz,
                         int stride, const real *x,
-                        Nbnxm::KernelType gmx_unused kernelType,
+                        ClusterDistanceKernelType gmx_unused kernelType,
                         NbnxnPairlistGpuWork *work)
 {
 #if !GMX_SIMD4_HAVE_REAL
@@ -3004,22 +3004,23 @@ static bool pairlistIsSimple(const NbnxnPairlistGpu gmx_unused &pairlist)
     return false;
 }
 
-static void makeClusterListWrapper(NbnxnPairlistCpu       *nbl,
-                                   const Grid gmx_unused  &iGrid,
-                                   const int               ci,
-                                   const Grid             &jGrid,
-                                   const int               firstCell,
-                                   const int               lastCell,
-                                   const bool              excludeSubDiagonal,
-                                   const nbnxn_atomdata_t *nbat,
-                                   const real              rlist2,
-                                   const real              rbb2,
-                                   const Nbnxm::KernelType kernelType,
-                                   int                    *numDistanceChecks)
+static void
+makeClusterListWrapper(NbnxnPairlistCpu                *nbl,
+                       const Grid gmx_unused           &iGrid,
+                       const int                        ci,
+                       const Grid                      &jGrid,
+                       const int                        firstCell,
+                       const int                        lastCell,
+                       const bool                       excludeSubDiagonal,
+                       const nbnxn_atomdata_t          *nbat,
+                       const real                       rlist2,
+                       const real                       rbb2,
+                       const ClusterDistanceKernelType  kernelType,
+                       int                             *numDistanceChecks)
 {
     switch (kernelType)
     {
-        case Nbnxm::KernelType::Cpu4x4_PlainC:
+        case ClusterDistanceKernelType::CpuPlainC:
             makeClusterListSimple(jGrid,
                                   nbl, ci, firstCell, lastCell,
                                   excludeSubDiagonal,
@@ -3028,7 +3029,7 @@ static void makeClusterListWrapper(NbnxnPairlistCpu       *nbl,
                                   numDistanceChecks);
             break;
 #ifdef GMX_NBNXN_SIMD_4XN
-        case Nbnxm::KernelType::Cpu4xN_Simd_4xN:
+        case ClusterDistanceKernelType::CpuSimd_4xM:
             makeClusterListSimd4xn(jGrid,
                                    nbl, ci, firstCell, lastCell,
                                    excludeSubDiagonal,
@@ -3038,7 +3039,7 @@ static void makeClusterListWrapper(NbnxnPairlistCpu       *nbl,
             break;
 #endif
 #ifdef GMX_NBNXN_SIMD_2XNN
-        case Nbnxm::KernelType::Cpu4xN_Simd_2xNN:
+        case ClusterDistanceKernelType::CpuSimd_2xMM:
             makeClusterListSimd2xnn(jGrid,
                                     nbl, ci, firstCell, lastCell,
                                     excludeSubDiagonal,
@@ -3052,18 +3053,19 @@ static void makeClusterListWrapper(NbnxnPairlistCpu       *nbl,
     }
 }
 
-static void makeClusterListWrapper(NbnxnPairlistGpu             *nbl,
-                                   const Grid &gmx_unused        iGrid,
-                                   const int                     ci,
-                                   const Grid                   &jGrid,
-                                   const int                     firstCell,
-                                   const int                     lastCell,
-                                   const bool                    excludeSubDiagonal,
-                                   const nbnxn_atomdata_t       *nbat,
-                                   const real                    rlist2,
-                                   const real                    rbb2,
-                                   Nbnxm::KernelType gmx_unused  kernelType,
-                                   int                          *numDistanceChecks)
+static void
+makeClusterListWrapper(NbnxnPairlistGpu                     *nbl,
+                       const Grid &gmx_unused                iGrid,
+                       const int                             ci,
+                       const Grid                           &jGrid,
+                       const int                             firstCell,
+                       const int                             lastCell,
+                       const bool                            excludeSubDiagonal,
+                       const nbnxn_atomdata_t               *nbat,
+                       const real                            rlist2,
+                       const real                            rbb2,
+                       ClusterDistanceKernelType gmx_unused  kernelType,
+                       int                                  *numDistanceChecks)
 {
     for (int cj = firstCell; cj <= lastCell; cj++)
     {
@@ -3148,7 +3150,7 @@ static void nbnxn_make_pairlist_part(const Nbnxm::GridSet &gridSet,
                                      const nbnxn_atomdata_t *nbat,
                                      const t_blocka &exclusions,
                                      real rlist,
-                                     const Nbnxm::KernelType kernelType,
+                                     const PairlistType pairlistType,
                                      int ci_block,
                                      gmx_bool bFBufferFlag,
                                      int nsubpair_max,
@@ -3182,7 +3184,7 @@ static void nbnxn_make_pairlist_part(const Nbnxm::GridSet &gridSet,
     sync_work(nbl);
     GMX_ASSERT(nbl->na_ci == jGrid.geometry().numAtomsICluster,
                "The cluster sizes in the list and grid should match");
-    nbl->na_cj = Nbnxm::JClusterSizePerKernelType[kernelType];
+    nbl->na_cj = JClusterSizePerListType[pairlistType];
     na_cj_2log = get_2log(nbl->na_cj);
 
     nbl->rlist  = rlist;
@@ -3202,6 +3204,10 @@ static void nbnxn_make_pairlist_part(const Nbnxm::GridSet &gridSet,
 
     const real            rlist2  = nbl->rlist*nbl->rlist;
 
+    // Select the cluster pair distance kernel type
+    const ClusterDistanceKernelType kernelType =
+        getClusterDistanceKernelType(pairlistType, *nbat);
+
     if (haveFep && !pairlistIsSimple(*nbl))
     {
         /* Determine an atom-pair list cut-off distance for FEP atom pairs.
@@ -3944,7 +3950,6 @@ PairlistSet::constructPairlists(const Nbnxm::GridSet          &gridSet,
                                 gmx::ArrayRef<PairsearchWork>  searchWork,
                                 nbnxn_atomdata_t              *nbat,
                                 const t_blocka                *excl,
-                                const Nbnxm::KernelType        kernelType,
                                 const int                      minimumIlistCountForGpuBalancing,
                                 t_nrnb                        *nrnb,
                                 SearchCycleCounting           *searchCycleCounting)
@@ -4083,7 +4088,7 @@ PairlistSet::constructPairlists(const Nbnxm::GridSet          &gridSet,
                         nbnxn_make_pairlist_part(gridSet, iGrid, jGrid,
                                                  &work, nbat, *excl,
                                                  rlist,
-                                                 kernelType,
+                                                 params_.pairlistType,
                                                  ci_block,
                                                  nbat->bUseBufferFlags,
                                                  nsubpair_target,
@@ -4097,7 +4102,7 @@ PairlistSet::constructPairlists(const Nbnxm::GridSet          &gridSet,
                         nbnxn_make_pairlist_part(gridSet, iGrid, jGrid,
                                                  &work, nbat, *excl,
                                                  rlist,
-                                                 kernelType,
+                                                 params_.pairlistType,
                                                  ci_block,
                                                  nbat->bUseBufferFlags,
                                                  nsubpair_target,
@@ -4262,12 +4267,11 @@ PairlistSets::construct(const InteractionLocality  iLocality,
                         PairSearch                *pairSearch,
                         nbnxn_atomdata_t          *nbat,
                         const t_blocka            *excl,
-                        const Nbnxm::KernelType    kernelType,
                         const int64_t              step,
                         t_nrnb                    *nrnb)
 {
     pairlistSet(iLocality).constructPairlists(pairSearch->gridSet(), pairSearch->work(),
-                                              nbat, excl, kernelType, minimumIlistCountForGpuBalancing_,
+                                              nbat, excl, minimumIlistCountForGpuBalancing_,
                                               nrnb, &pairSearch->cycleCounting_);
 
     if (iLocality == Nbnxm::InteractionLocality::Local)
@@ -4300,7 +4304,6 @@ nonbonded_verlet_t::constructPairlist(const Nbnxm::InteractionLocality  iLocalit
                                       t_nrnb                           *nrnb)
 {
     pairlistSets_->construct(iLocality, pairSearch_.get(), nbat.get(), excl,
-                             kernelSetup_.kernelType,
                              step, nrnb);
 
     if (useGpu())