Remove nbnxm kernel type from pairlist generation
authorBerk Hess <hess@kth.se>
Thu, 21 Mar 2019 21:09:30 +0000 (22:09 +0100)
committerMark Abraham <mark.j.abraham@gmail.com>
Thu, 18 Apr 2019 15:45:02 +0000 (17:45 +0200)
The type of kernel for computing distances between atom clusters pairs
is no longer set directly by the non-bonded interaction kernel type.
A new type is added for this, although the actual choice of cluster
distance kernel type has actually not changed.
Also moved the declaration of PairlistType to pairlistparams.h.

Change-Id: Iaf38ca7804eed75295e6cb0e1176a23d620c6f0c

src/gromacs/nbnxm/clusterdistancekerneltype.h [new file with mode: 0644]
src/gromacs/nbnxm/grid.cpp
src/gromacs/nbnxm/pairlist.cpp
src/gromacs/nbnxm/pairlist.h
src/gromacs/nbnxm/pairlistparams.h
src/gromacs/nbnxm/pairlistset.h
src/gromacs/nbnxm/pairlistsets.h
src/gromacs/nbnxm/prunekerneldispatch.cpp

diff --git a/src/gromacs/nbnxm/clusterdistancekerneltype.h b/src/gromacs/nbnxm/clusterdistancekerneltype.h
new file mode 100644 (file)
index 0000000..ad5f25f
--- /dev/null
@@ -0,0 +1,110 @@
+/*
+ * This file is part of the GROMACS molecular simulation package.
+ *
+ * Copyright (c) 2019, by the GROMACS development team, led by
+ * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
+ * and including many others, as listed in the AUTHORS file in the
+ * top-level source directory and at http://www.gromacs.org.
+ *
+ * GROMACS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1
+ * of the License, or (at your option) any later version.
+ *
+ * GROMACS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with GROMACS; if not, see
+ * http://www.gnu.org/licenses, or write to the Free Software Foundation,
+ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
+ *
+ * If you want to redistribute modifications to GROMACS, please
+ * consider that scientific software is very special. Version
+ * control is crucial - bugs must be traceable. We will be happy to
+ * consider code for inclusion in the official distribution, but
+ * derived work must not be called official GROMACS. Details are found
+ * in the README & COPYING files - if they are missing, get the
+ * official version at http://www.gromacs.org.
+ *
+ * To help us fund GROMACS development, we humbly ask that you cite
+ * the research papers on the package. Check out http://www.gromacs.org.
+ */
+
+/*! \internal \file
+ *
+ * \brief
+ * Declares the ClusterDistanceKernelType enum
+ *
+ * \author Berk Hess <hess@kth.se>
+ * \ingroup module_nbnxm
+ */
+
+#ifndef GMX_NBNXM_CLUSTERDISTANCEKERNELTYPE_H
+#define GMX_NBNXM_CLUSTERDISTANCEKERNELTYPE_H
+
+#include "gromacs/nbnxm/atomdata.h"
+#include "gromacs/simd/simd.h"
+#include "gromacs/utility/gmxassert.h"
+
+#include "pairlistparams.h"
+
+//! The types of kernel for calculating the distance between pairs of atom clusters
+enum class ClusterDistanceKernelType : int
+{
+    CpuPlainC,    //!< Plain-C for CPU list
+    CpuSimd_4xM,  //!< SIMD for CPU list for j-cluster size matching the SIMD width
+    CpuSimd_2xMM, //!< SIMD for CPU list for j-cluster size matching half the SIMD width
+    Gpu           //!< For GPU list, can be either plain-C or SIMD
+};
+
+//! Return the cluster distance kernel type given the pairlist type and atomdata
+static inline ClusterDistanceKernelType
+getClusterDistanceKernelType(const PairlistType      pairlistType,
+                             const nbnxn_atomdata_t &atomdata)
+{
+    if (pairlistType == PairlistType::HierarchicalNxN)
+    {
+        return ClusterDistanceKernelType::Gpu;
+    }
+    else if (atomdata.XFormat == nbatXYZ)
+    {
+        return ClusterDistanceKernelType::CpuPlainC;
+    }
+    else if (pairlistType == PairlistType::Simple4x2)
+    {
+#if GMX_SIMD && GMX_SIMD_REAL_WIDTH == 2
+        return ClusterDistanceKernelType::CpuSimd_4xM;
+#else
+        GMX_RELEASE_ASSERT(false, "Expect 2-wide SIMD with 4x2 list and nbat SIMD layout");
+#endif
+    }
+    else if (pairlistType == PairlistType::Simple4x4)
+    {
+#if GMX_SIMD && GMX_SIMD_REAL_WIDTH == 4
+        return ClusterDistanceKernelType::CpuSimd_4xM;
+#elif GMX_SIMD && GMX_SIMD_REAL_WIDTH == 8
+        return ClusterDistanceKernelType::CpuSimd_2xMM;
+#else
+        GMX_RELEASE_ASSERT(false, "Expect 4-wide or 8-wide SIMD with 4x4 list and nbat SIMD layout");
+#endif
+    }
+    else
+    {
+        GMX_ASSERT(pairlistType == PairlistType::Simple4x8, "Unhandled pairlist type");
+#if GMX_SIMD && GMX_SIMD_REAL_WIDTH == 8
+        return ClusterDistanceKernelType::CpuSimd_4xM;
+#elif GMX_SIMD && GMX_SIMD_REAL_WIDTH == 16
+        return ClusterDistanceKernelType::CpuSimd_2xMM;
+#else
+        GMX_RELEASE_ASSERT(false, "Expect 8-wide or 16-wide SIMD with 4x4 list and nbat SIMD layout");
+#endif
+    }
+
+    GMX_RELEASE_ASSERT(false, "We should have returned before getting here");
+    return ClusterDistanceKernelType::CpuPlainC;
+}
+
+#endif
index 132ae556e1454a1ecfa50210cb02e7cfae4eaa95..2cd41950f2db46770d5f97c6e49ec4c9ae290f2a 100644 (file)
@@ -61,6 +61,8 @@
 #include "gromacs/simd/simd.h"
 #include "gromacs/simd/vector_operations.h"
 
+#include "pairlistparams.h"
+
 namespace Nbnxm
 {
 
index 9b62b763c9174c628667878010c1cc0c8eba9e35..6e0fde9908bccc65139923a742f0a4b8e5da2fec 100644 (file)
@@ -55,7 +55,6 @@
 #include "gromacs/mdtypes/md_enums.h"
 #include "gromacs/nbnxm/atomdata.h"
 #include "gromacs/nbnxm/gpu_data_mgmt.h"
-#include "gromacs/nbnxm/nbnxm.h"
 #include "gromacs/nbnxm/nbnxm_geometry.h"
 #include "gromacs/nbnxm/nbnxm_simd.h"
 #include "gromacs/pbcutil/ishift.h"
@@ -68,6 +67,7 @@
 #include "gromacs/utility/gmxomp.h"
 #include "gromacs/utility/smalloc.h"
 
+#include "clusterdistancekerneltype.h"
 #include "gridset.h"
 #include "pairlistset.h"
 #include "pairlistsets.h"
@@ -2346,24 +2346,24 @@ static void icell_set_x_simple(int ci,
 static void icell_set_x(int ci,
                         real shx, real shy, real shz,
                         int stride, const real *x,
-                        const Nbnxm::KernelType kernelType,
+                        const ClusterDistanceKernelType kernelType,
                         NbnxnPairlistCpuWork *work)
 {
     switch (kernelType)
     {
 #if GMX_SIMD
 #ifdef GMX_NBNXN_SIMD_4XN
-        case Nbnxm::KernelType::Cpu4xN_Simd_4xN:
+        case ClusterDistanceKernelType::CpuSimd_4xM:
             icell_set_x_simd_4xn(ci, shx, shy, shz, stride, x, work);
             break;
 #endif
 #ifdef GMX_NBNXN_SIMD_2XNN
-        case Nbnxm::KernelType::Cpu4xN_Simd_2xNN:
+        case ClusterDistanceKernelType::CpuSimd_2xMM:
             icell_set_x_simd_2xnn(ci, shx, shy, shz, stride, x, work);
             break;
 #endif
 #endif
-        case Nbnxm::KernelType::Cpu4x4_PlainC:
+        case ClusterDistanceKernelType::CpuPlainC:
             icell_set_x_simple(ci, shx, shy, shz, stride, x, &work->iClusterData);
             break;
         default:
@@ -2376,7 +2376,7 @@ static void icell_set_x(int ci,
 static void icell_set_x(int ci,
                         real shx, real shy, real shz,
                         int stride, const real *x,
-                        Nbnxm::KernelType gmx_unused kernelType,
+                        ClusterDistanceKernelType gmx_unused kernelType,
                         NbnxnPairlistGpuWork *work)
 {
 #if !GMX_SIMD4_HAVE_REAL
@@ -3004,22 +3004,23 @@ static bool pairlistIsSimple(const NbnxnPairlistGpu gmx_unused &pairlist)
     return false;
 }
 
-static void makeClusterListWrapper(NbnxnPairlistCpu       *nbl,
-                                   const Grid gmx_unused  &iGrid,
-                                   const int               ci,
-                                   const Grid             &jGrid,
-                                   const int               firstCell,
-                                   const int               lastCell,
-                                   const bool              excludeSubDiagonal,
-                                   const nbnxn_atomdata_t *nbat,
-                                   const real              rlist2,
-                                   const real              rbb2,
-                                   const Nbnxm::KernelType kernelType,
-                                   int                    *numDistanceChecks)
+static void
+makeClusterListWrapper(NbnxnPairlistCpu                *nbl,
+                       const Grid gmx_unused           &iGrid,
+                       const int                        ci,
+                       const Grid                      &jGrid,
+                       const int                        firstCell,
+                       const int                        lastCell,
+                       const bool                       excludeSubDiagonal,
+                       const nbnxn_atomdata_t          *nbat,
+                       const real                       rlist2,
+                       const real                       rbb2,
+                       const ClusterDistanceKernelType  kernelType,
+                       int                             *numDistanceChecks)
 {
     switch (kernelType)
     {
-        case Nbnxm::KernelType::Cpu4x4_PlainC:
+        case ClusterDistanceKernelType::CpuPlainC:
             makeClusterListSimple(jGrid,
                                   nbl, ci, firstCell, lastCell,
                                   excludeSubDiagonal,
@@ -3028,7 +3029,7 @@ static void makeClusterListWrapper(NbnxnPairlistCpu       *nbl,
                                   numDistanceChecks);
             break;
 #ifdef GMX_NBNXN_SIMD_4XN
-        case Nbnxm::KernelType::Cpu4xN_Simd_4xN:
+        case ClusterDistanceKernelType::CpuSimd_4xM:
             makeClusterListSimd4xn(jGrid,
                                    nbl, ci, firstCell, lastCell,
                                    excludeSubDiagonal,
@@ -3038,7 +3039,7 @@ static void makeClusterListWrapper(NbnxnPairlistCpu       *nbl,
             break;
 #endif
 #ifdef GMX_NBNXN_SIMD_2XNN
-        case Nbnxm::KernelType::Cpu4xN_Simd_2xNN:
+        case ClusterDistanceKernelType::CpuSimd_2xMM:
             makeClusterListSimd2xnn(jGrid,
                                     nbl, ci, firstCell, lastCell,
                                     excludeSubDiagonal,
@@ -3052,18 +3053,19 @@ static void makeClusterListWrapper(NbnxnPairlistCpu       *nbl,
     }
 }
 
-static void makeClusterListWrapper(NbnxnPairlistGpu             *nbl,
-                                   const Grid &gmx_unused        iGrid,
-                                   const int                     ci,
-                                   const Grid                   &jGrid,
-                                   const int                     firstCell,
-                                   const int                     lastCell,
-                                   const bool                    excludeSubDiagonal,
-                                   const nbnxn_atomdata_t       *nbat,
-                                   const real                    rlist2,
-                                   const real                    rbb2,
-                                   Nbnxm::KernelType gmx_unused  kernelType,
-                                   int                          *numDistanceChecks)
+static void
+makeClusterListWrapper(NbnxnPairlistGpu                     *nbl,
+                       const Grid &gmx_unused                iGrid,
+                       const int                             ci,
+                       const Grid                           &jGrid,
+                       const int                             firstCell,
+                       const int                             lastCell,
+                       const bool                            excludeSubDiagonal,
+                       const nbnxn_atomdata_t               *nbat,
+                       const real                            rlist2,
+                       const real                            rbb2,
+                       ClusterDistanceKernelType gmx_unused  kernelType,
+                       int                                  *numDistanceChecks)
 {
     for (int cj = firstCell; cj <= lastCell; cj++)
     {
@@ -3148,7 +3150,7 @@ static void nbnxn_make_pairlist_part(const Nbnxm::GridSet &gridSet,
                                      const nbnxn_atomdata_t *nbat,
                                      const t_blocka &exclusions,
                                      real rlist,
-                                     const Nbnxm::KernelType kernelType,
+                                     const PairlistType pairlistType,
                                      int ci_block,
                                      gmx_bool bFBufferFlag,
                                      int nsubpair_max,
@@ -3182,7 +3184,7 @@ static void nbnxn_make_pairlist_part(const Nbnxm::GridSet &gridSet,
     sync_work(nbl);
     GMX_ASSERT(nbl->na_ci == jGrid.geometry().numAtomsICluster,
                "The cluster sizes in the list and grid should match");
-    nbl->na_cj = Nbnxm::JClusterSizePerKernelType[kernelType];
+    nbl->na_cj = JClusterSizePerListType[pairlistType];
     na_cj_2log = get_2log(nbl->na_cj);
 
     nbl->rlist  = rlist;
@@ -3202,6 +3204,10 @@ static void nbnxn_make_pairlist_part(const Nbnxm::GridSet &gridSet,
 
     const real            rlist2  = nbl->rlist*nbl->rlist;
 
+    // Select the cluster pair distance kernel type
+    const ClusterDistanceKernelType kernelType =
+        getClusterDistanceKernelType(pairlistType, *nbat);
+
     if (haveFep && !pairlistIsSimple(*nbl))
     {
         /* Determine an atom-pair list cut-off distance for FEP atom pairs.
@@ -3944,7 +3950,6 @@ PairlistSet::constructPairlists(const Nbnxm::GridSet          &gridSet,
                                 gmx::ArrayRef<PairsearchWork>  searchWork,
                                 nbnxn_atomdata_t              *nbat,
                                 const t_blocka                *excl,
-                                const Nbnxm::KernelType        kernelType,
                                 const int                      minimumIlistCountForGpuBalancing,
                                 t_nrnb                        *nrnb,
                                 SearchCycleCounting           *searchCycleCounting)
@@ -4083,7 +4088,7 @@ PairlistSet::constructPairlists(const Nbnxm::GridSet          &gridSet,
                         nbnxn_make_pairlist_part(gridSet, iGrid, jGrid,
                                                  &work, nbat, *excl,
                                                  rlist,
-                                                 kernelType,
+                                                 params_.pairlistType,
                                                  ci_block,
                                                  nbat->bUseBufferFlags,
                                                  nsubpair_target,
@@ -4097,7 +4102,7 @@ PairlistSet::constructPairlists(const Nbnxm::GridSet          &gridSet,
                         nbnxn_make_pairlist_part(gridSet, iGrid, jGrid,
                                                  &work, nbat, *excl,
                                                  rlist,
-                                                 kernelType,
+                                                 params_.pairlistType,
                                                  ci_block,
                                                  nbat->bUseBufferFlags,
                                                  nsubpair_target,
@@ -4262,12 +4267,11 @@ PairlistSets::construct(const InteractionLocality  iLocality,
                         PairSearch                *pairSearch,
                         nbnxn_atomdata_t          *nbat,
                         const t_blocka            *excl,
-                        const Nbnxm::KernelType    kernelType,
                         const int64_t              step,
                         t_nrnb                    *nrnb)
 {
     pairlistSet(iLocality).constructPairlists(pairSearch->gridSet(), pairSearch->work(),
-                                              nbat, excl, kernelType, minimumIlistCountForGpuBalancing_,
+                                              nbat, excl, minimumIlistCountForGpuBalancing_,
                                               nrnb, &pairSearch->cycleCounting_);
 
     if (iLocality == Nbnxm::InteractionLocality::Local)
@@ -4300,7 +4304,6 @@ nonbonded_verlet_t::constructPairlist(const Nbnxm::InteractionLocality  iLocalit
                                       t_nrnb                           *nrnb)
 {
     pairlistSets_->construct(iLocality, pairSearch_.get(), nbat.get(), excl,
-                             kernelSetup_.kernelType,
                              step, nrnb);
 
     if (useGpu())
index da19d70f9e208a8eaf3054508847c902ce23e29b..19d898fcd0b9c603a162ae7d523ed53e62931036 100644 (file)
@@ -36,8 +36,6 @@
 #ifndef GMX_NBNXM_PAIRLIST_H
 #define GMX_NBNXM_PAIRLIST_H
 
-#include "config.h"
-
 #include <cstddef>
 
 #include "gromacs/gpu_utils/hostallocator.h"
@@ -53,6 +51,7 @@
 #include "gromacs/nbnxm/constants.h"
 
 #include "locality.h"
+#include "pairlistparams.h"
 
 struct NbnxnPairlistCpuWork;
 struct NbnxnPairlistGpuWork;
@@ -66,58 +65,6 @@ using AlignedVector = std::vector < T, gmx::AlignedAllocator < T>>;
 template<typename T>
 using FastVector = std::vector < T, gmx::DefaultInitializationAllocator < T>>;
 
-/* With CPU kernels the i-cluster size is always 4 atoms. */
-static constexpr int c_nbnxnCpuIClusterSize = 4;
-
-/* With GPU kernels the i and j cluster size is 8 atoms for CUDA and can be set at compile time for OpenCL */
-#if GMX_GPU == GMX_GPU_OPENCL
-static constexpr int c_nbnxnGpuClusterSize = GMX_OPENCL_NB_CLUSTER_SIZE;
-#else
-static constexpr int c_nbnxnGpuClusterSize = 8;
-#endif
-
-/* The number of clusters in a pair-search cell, used for GPU */
-static constexpr int c_gpuNumClusterPerCellZ = 2;
-static constexpr int c_gpuNumClusterPerCellY = 2;
-static constexpr int c_gpuNumClusterPerCellX = 2;
-static constexpr int c_gpuNumClusterPerCell  = c_gpuNumClusterPerCellZ*c_gpuNumClusterPerCellY*c_gpuNumClusterPerCellX;
-
-
-/* In CUDA the number of threads in a warp is 32 and we have cluster pairs
- * of 8*8=64 atoms, so it's convenient to store data for cluster pair halves.
- */
-static constexpr int c_nbnxnGpuClusterpairSplit = 2;
-
-/* The fixed size of the exclusion mask array for a half cluster pair */
-static constexpr int c_nbnxnGpuExclSize = c_nbnxnGpuClusterSize*c_nbnxnGpuClusterSize/c_nbnxnGpuClusterpairSplit;
-
-//! The available pair list types
-enum class PairlistType : int
-{
-    Simple4x2,
-    Simple4x4,
-    Simple4x8,
-    HierarchicalNxN,
-    Count
-};
-
-//! Gives the i-cluster size for each pairlist type
-static constexpr gmx::EnumerationArray<PairlistType, int> IClusterSizePerListType =
-{ {
-      c_nbnxnCpuIClusterSize,
-      c_nbnxnCpuIClusterSize,
-      c_nbnxnCpuIClusterSize,
-      c_nbnxnGpuClusterSize
-  } };
-//! Gives the j-cluster size for each pairlist type
-static constexpr gmx::EnumerationArray<PairlistType, int> JClusterSizePerListType =
-{ {
-      2,
-      4,
-      8,
-      c_nbnxnGpuClusterSize
-  } };
-
 /* A buffer data structure of 64 bytes
  * to be placed at the beginning and end of structs
  * to avoid cache invalidation of the real contents
index b897b6551ef738387a77282dc85bef6273433cdb..78caf48e3cb6a3087439e465346d1b8b0b1fd179 100644 (file)
@@ -36,9 +36,7 @@
 /*! \internal \file
  *
  * \brief
- * Declares the PairlistParams class
- *
- * This class holds the Nbnxm pairlist parameters.
+ * Declares the PairlistType enum and PairlistParams class
  *
  * \author Berk Hess <hess@kth.se>
  * \ingroup module_nbnxm
 #ifndef GMX_NBNXM_PAIRLISTPARAMS_H
 #define GMX_NBNXM_PAIRLISTPARAMS_H
 
+#include "config.h"
+
+#include "gromacs/utility/enumerationhelpers.h"
 #include "gromacs/utility/real.h"
 
 #include "locality.h"
 
-enum class PairlistType;
-
 namespace Nbnxm
 {
 enum class KernelType;
 }
 
+//! The i-cluster size for CPU kernels, always 4 atoms
+static constexpr int c_nbnxnCpuIClusterSize = 4;
+
+//! The i- and j-cluster size for GPU lists, 8 atoms for CUDA, set at compile time for OpenCL
+#if GMX_GPU == GMX_GPU_OPENCL
+static constexpr int c_nbnxnGpuClusterSize = GMX_OPENCL_NB_CLUSTER_SIZE;
+#else
+static constexpr int c_nbnxnGpuClusterSize = 8;
+#endif
+
+//! The number of clusters along Z in a pair-search grid cell for GPU lists
+static constexpr int c_gpuNumClusterPerCellZ = 2;
+//! The number of clusters along Y in a pair-search grid cell for GPU lists
+static constexpr int c_gpuNumClusterPerCellY = 2;
+//! The number of clusters along X in a pair-search grid cell for GPU lists
+static constexpr int c_gpuNumClusterPerCellX = 2;
+//! The number of clusters in a pair-search grid cell for GPU lists
+static constexpr int c_gpuNumClusterPerCell  = c_gpuNumClusterPerCellZ*c_gpuNumClusterPerCellY*c_gpuNumClusterPerCellX;
+
+
+/*! \brief The number of sub-parts used for data storage for a GPU cluster pair
+ *
+ * In CUDA the number of threads in a warp is 32 and we have cluster pairs
+ * of 8*8=64 atoms, so it's convenient to store data for cluster pair halves.
+ */
+static constexpr int c_nbnxnGpuClusterpairSplit = 2;
+
+//! The fixed size of the exclusion mask array for a half GPU cluster pair
+static constexpr int c_nbnxnGpuExclSize = c_nbnxnGpuClusterSize*c_nbnxnGpuClusterSize/c_nbnxnGpuClusterpairSplit;
+
+//! The available pair list types
+enum class PairlistType : int
+{
+    Simple4x2,
+    Simple4x4,
+    Simple4x8,
+    HierarchicalNxN,
+    Count
+};
+
+//! Gives the i-cluster size for each pairlist type
+static constexpr gmx::EnumerationArray<PairlistType, int> IClusterSizePerListType =
+{ {
+      c_nbnxnCpuIClusterSize,
+      c_nbnxnCpuIClusterSize,
+      c_nbnxnCpuIClusterSize,
+      c_nbnxnGpuClusterSize
+  } };
+//! Gives the j-cluster size for each pairlist type
+static constexpr gmx::EnumerationArray<PairlistType, int> JClusterSizePerListType =
+{ {
+      2,
+      4,
+      8,
+      c_nbnxnGpuClusterSize
+  } };
 
 /*! \internal
  * \brief The setup for generating and pruning the nbnxn pair list.
index 1eea57d0e44d08a90b7f387d5c7330c433332bf6..758d10f19660e2d505584c5249d71af465b169a9 100644 (file)
 #define GMX_NBNXM_PAIRLISTSET_H
 
 #include "gromacs/math/vectypes.h"
-#include "gromacs/nbnxm/nbnxm.h"
 #include "gromacs/nbnxm/pairlist.h"
 #include "gromacs/utility/basedefinitions.h"
 #include "gromacs/utility/real.h"
 
 #include "locality.h"
 
+struct nbnxn_atomdata_t;
 struct PairlistParams;
 struct PairsearchWork;
 struct SearchCycleCounting;
+struct t_blocka;
 struct t_nrnb;
 
 namespace Nbnxm
@@ -84,15 +85,13 @@ class PairlistSet
                                 gmx::ArrayRef<PairsearchWork>  searchWork,
                                 nbnxn_atomdata_t              *nbat,
                                 const t_blocka                *excl,
-                                Nbnxm::KernelType              kernelType,
                                 int                            minimumIlistCountForGpuBalancing,
                                 t_nrnb                        *nrnb,
                                 SearchCycleCounting           *searchCycleCounting);
 
         //! Dispatch the kernel for dynamic pairlist pruning
         void dispatchPruneKernel(const nbnxn_atomdata_t *nbat,
-                                 const rvec             *shift_vec,
-                                 Nbnxm::KernelType       kernelType);
+                                 const rvec             *shift_vec);
 
         //! Returns the locality
         Nbnxm::InteractionLocality locality() const
index ccfc0992d5bd309c2d60a89b52c36c71a3652575..41e8dc25b50465102a88c8724a099867ffc598b7 100644 (file)
@@ -59,11 +59,6 @@ class PairSearch;
 struct t_blocka;
 struct t_nrnb;
 
-namespace Nbnxm
-{
-enum class KernelType;
-}
-
 
 class PairlistSets
 {
@@ -77,15 +72,13 @@ class PairlistSets
                        PairSearch                 *pairSearch,
                        nbnxn_atomdata_t           *nbat,
                        const t_blocka             *excl,
-                       Nbnxm::KernelType           kernelbType,
                        int64_t                     step,
                        t_nrnb                     *nrnb);
 
         //! Dispatches the dynamic pruning kernel for the given locality
         void dispatchPruneKernel(Nbnxm::InteractionLocality  iLocality,
                                  const nbnxn_atomdata_t     *nbat,
-                                 const rvec                 *shift_vec,
-                                 Nbnxm::KernelType           kernelType);
+                                 const rvec                 *shift_vec);
 
         //! Returns the pair list parameters
         const PairlistParams &params() const
index 2ee9a5186115cb74a8235f18f42f540f1a01866b..899b9489be6af794447aebaebf98b4fac1f85bd9 100644 (file)
@@ -39,6 +39,7 @@
 #include "gromacs/nbnxm/nbnxm.h"
 #include "gromacs/utility/gmxassert.h"
 
+#include "clusterdistancekerneltype.h"
 #include "pairlistset.h"
 #include "pairlistsets.h"
 #include "kernels_reference/kernel_ref_prune.h"
 void
 PairlistSets::dispatchPruneKernel(const Nbnxm::InteractionLocality  iLocality,
                                   const nbnxn_atomdata_t           *nbat,
-                                  const rvec                       *shift_vec,
-                                  const Nbnxm::KernelType           kernelType)
+                                  const rvec                       *shift_vec)
 {
-    pairlistSet(iLocality).dispatchPruneKernel(nbat, shift_vec, kernelType);
+    pairlistSet(iLocality).dispatchPruneKernel(nbat, shift_vec);
 }
 
 void
 PairlistSet::dispatchPruneKernel(const nbnxn_atomdata_t  *nbat,
-                                 const rvec              *shift_vec,
-                                 const Nbnxm::KernelType  kernelType)
+                                 const rvec              *shift_vec)
 {
     const real rlistInner = params_.rlistInner;
 
@@ -72,15 +71,15 @@ PairlistSet::dispatchPruneKernel(const nbnxn_atomdata_t  *nbat,
     {
         NbnxnPairlistCpu *nbl = &cpuLists_[i];
 
-        switch (kernelType)
+        switch (getClusterDistanceKernelType(params_.pairlistType, *nbat))
         {
-            case Nbnxm::KernelType::Cpu4xN_Simd_4xN:
+            case ClusterDistanceKernelType::CpuSimd_4xM:
                 nbnxn_kernel_prune_4xn(nbl, nbat, shift_vec, rlistInner);
                 break;
-            case Nbnxm::KernelType::Cpu4xN_Simd_2xNN:
+            case ClusterDistanceKernelType::CpuSimd_2xMM:
                 nbnxn_kernel_prune_2xnn(nbl, nbat, shift_vec, rlistInner);
                 break;
-            case Nbnxm::KernelType::Cpu4x4_PlainC:
+            case ClusterDistanceKernelType::CpuPlainC:
                 nbnxn_kernel_prune_ref(nbl, nbat, shift_vec, rlistInner);
                 break;
             default:
@@ -93,7 +92,7 @@ void
 nonbonded_verlet_t::dispatchPruneKernelCpu(const Nbnxm::InteractionLocality  iLocality,
                                            const rvec                       *shift_vec)
 {
-    pairlistSets_->dispatchPruneKernel(iLocality, nbat.get(), shift_vec, kernelSetup_.kernelType);
+    pairlistSets_->dispatchPruneKernel(iLocality, nbat.get(), shift_vec);
 }
 
 void nonbonded_verlet_t::dispatchPruneKernelGpu(int64_t step)