Expose and test getVdwKernelType
authorJoe Jordan <ejjordan12@gmail.com>
Wed, 9 Jun 2021 11:19:25 +0000 (11:19 +0000)
committerPaul Bauer <paul.bauer.q@gmail.com>
Wed, 9 Jun 2021 11:19:25 +0000 (11:19 +0000)
src/gromacs/nbnxm/kernel_common.h
src/gromacs/nbnxm/kerneldispatch.cpp
src/gromacs/nbnxm/tests/kernelsetup.cpp

index f107ea591836a77d5789886622fa0af9280251a5..93d4b8c4700417af20685f7b52f598c5a9459524 100644 (file)
@@ -55,6 +55,9 @@
 
 struct interaction_const_t;
 enum class CoulombInteractionType : int;
+enum class VanDerWaalsType : int;
+enum class InteractionModifiers : int;
+enum class LongRangeVdW : int;
 
 namespace Nbnxm
 {
@@ -126,6 +129,13 @@ enum
     vdwktNR_ref
 };
 
+//! \brief Lookup function for Vdw kernel type
+int getVdwKernelType(Nbnxm::KernelType    kernelType,
+                     LJCombinationRule    ljCombinationRule,
+                     VanDerWaalsType      vanDerWaalsType,
+                     InteractionModifiers interactionModifiers,
+                     LongRangeVdW         longRangeVdW);
+
 /*! \brief Clears the force buffer.
  *
  * Either the whole buffer is cleared or only the parts used
index 174dccf1fa3ba7a18998325339372e1308438513..f7229169e3f30ad76794730f275cd1198f54e2f5 100644 (file)
@@ -171,43 +171,45 @@ CoulombKernelType getCoulombKernelType(const Nbnxm::EwaldExclusionType ewaldExcl
     }
 }
 
-static int getVdwKernelType(const Nbnxm::KernelSetup&       kernelSetup,
-                            const nbnxn_atomdata_t::Params& nbatParams,
-                            const interaction_const_t&      ic)
+int getVdwKernelType(const Nbnxm::KernelType    kernelType,
+                     const LJCombinationRule    ljCombinationRule,
+                     const VanDerWaalsType      vanDerWaalsType,
+                     const InteractionModifiers interactionModifiers,
+                     const LongRangeVdW         longRangeVdW)
 {
-    if (ic.vdwtype == VanDerWaalsType::Cut)
+    if (vanDerWaalsType == VanDerWaalsType::Cut)
     {
-        switch (ic.vdw_modifier)
+        switch (interactionModifiers)
         {
             case InteractionModifiers::None:
             case InteractionModifiers::PotShift:
-                switch (nbatParams.ljCombinationRule)
+                switch (ljCombinationRule)
                 {
                     case LJCombinationRule::Geometric: return vdwktLJCUT_COMBGEOM;
                     case LJCombinationRule::LorentzBerthelot: return vdwktLJCUT_COMBLB;
                     case LJCombinationRule::None: return vdwktLJCUT_COMBNONE;
-                    default: gmx_incons("Unknown combination rule");
+                    default: GMX_THROW(gmx::InvalidInputError("Unknown combination rule"));
                 }
             case InteractionModifiers::ForceSwitch: return vdwktLJFORCESWITCH;
             case InteractionModifiers::PotSwitch: return vdwktLJPOTSWITCH;
             default:
                 std::string errorMsg =
                         gmx::formatString("Unsupported VdW interaction modifier %s (%d)",
-                                          enumValueToString(ic.vdw_modifier),
-                                          static_cast<int>(ic.vdw_modifier));
-                gmx_incons(errorMsg);
+                                          enumValueToString(interactionModifiers),
+                                          static_cast<int>(interactionModifiers));
+                GMX_THROW(gmx::InvalidInputError(errorMsg));
         }
     }
-    else if (ic.vdwtype == VanDerWaalsType::Pme)
+    else if (vanDerWaalsType == VanDerWaalsType::Pme)
     {
-        if (ic.ljpme_comb_rule == LongRangeVdW::Geom)
+        if (longRangeVdW == LongRangeVdW::Geom)
         {
             return vdwktLJEWALDCOMBGEOM;
         }
         else
         {
             /* At setup we (should have) selected the C reference kernel */
-            GMX_RELEASE_ASSERT(kernelSetup.kernelType == Nbnxm::KernelType::Cpu4x4_PlainC,
+            GMX_RELEASE_ASSERT(kernelType == Nbnxm::KernelType::Cpu4x4_PlainC,
                                "Only the C reference nbnxn SIMD kernel supports LJ-PME with LB "
                                "combination rules");
             return vdwktLJEWALDCOMBLB;
@@ -216,9 +218,9 @@ static int getVdwKernelType(const Nbnxm::KernelSetup&       kernelSetup,
     else
     {
         std::string errorMsg = gmx::formatString("Unsupported VdW interaction type %s (%d)",
-                                                 enumValueToString(ic.vdwtype),
-                                                 static_cast<int>(ic.vdwtype));
-        gmx_incons(errorMsg);
+                                                 enumValueToString(vanDerWaalsType),
+                                                 static_cast<int>(vanDerWaalsType));
+        GMX_THROW(gmx::InvalidInputError(errorMsg));
     }
 }
 
@@ -255,7 +257,8 @@ static void nbnxn_kernel_cpu(const PairlistSet&             pairlistSet,
 
     const int coulkt = static_cast<int>(getCoulombKernelType(
             kernelSetup.ewaldExclusionType, ic.eeltype, (ic.rcoulomb == ic.rvdw)));
-    const int vdwkt  = getVdwKernelType(kernelSetup, nbatParams, ic);
+    const int vdwkt  = getVdwKernelType(
+            kernelSetup.kernelType, nbatParams.ljCombinationRule, ic.vdwtype, ic.vdw_modifier, ic.ljpme_comb_rule);
 
     gmx::ArrayRef<const NbnxnPairlistCpu> pairlists = pairlistSet.cpuLists();
 
index aef62d6704362096b0a49ac76a09495c1ea9e345..2ff8c00fc4e8e9f2374b0f62f9ac74e1c3f2000e 100644 (file)
@@ -92,6 +92,126 @@ TEST(KernelSetupTest, getCoulombKernelTypeEwaldTwin)
               CoulombKernelType::EwaldTwin);
 }
 
+TEST(KernelSetupTest, getVdwKernelTypeLjCutCombGeomNone)
+{
+    EXPECT_EQ(getVdwKernelType(Nbnxm::KernelType::NotSet,
+                               LJCombinationRule::Geometric,
+                               VanDerWaalsType::Cut,
+                               InteractionModifiers::None,
+                               LongRangeVdW::Count),
+              vdwktLJCUT_COMBGEOM);
+}
+
+TEST(KernelSetupTest, getVdwKernelTypeLjCutCombGeomPotShift)
+{
+    EXPECT_EQ(getVdwKernelType(Nbnxm::KernelType::NotSet,
+                               LJCombinationRule::Geometric,
+                               VanDerWaalsType::Cut,
+                               InteractionModifiers::PotShift,
+                               LongRangeVdW::Count),
+              vdwktLJCUT_COMBGEOM);
+}
+
+TEST(KernelSetupTest, getVdwKernelTypeLjCutCombLBNone)
+{
+    EXPECT_EQ(getVdwKernelType(Nbnxm::KernelType::NotSet,
+                               LJCombinationRule::LorentzBerthelot,
+                               VanDerWaalsType::Cut,
+                               InteractionModifiers::None,
+                               LongRangeVdW::Count),
+              vdwktLJCUT_COMBLB);
+}
+
+TEST(KernelSetupTest, getVdwKernelTypeLjCutCombLBPotShift)
+{
+    EXPECT_EQ(getVdwKernelType(Nbnxm::KernelType::NotSet,
+                               LJCombinationRule::LorentzBerthelot,
+                               VanDerWaalsType::Cut,
+                               InteractionModifiers::PotShift,
+                               LongRangeVdW::Count),
+              vdwktLJCUT_COMBLB);
+}
+
+TEST(KernelSetupTest, getVdwKernelTypeLjCutCombNoneNone)
+{
+    EXPECT_EQ(getVdwKernelType(Nbnxm::KernelType::NotSet,
+                               LJCombinationRule::None,
+                               VanDerWaalsType::Cut,
+                               InteractionModifiers::None,
+                               LongRangeVdW::Count),
+              vdwktLJCUT_COMBNONE);
+}
+
+TEST(KernelSetupTest, getVdwKernelTypeLjCutCombNonePotShift)
+{
+    EXPECT_EQ(getVdwKernelType(Nbnxm::KernelType::NotSet,
+                               LJCombinationRule::None,
+                               VanDerWaalsType::Cut,
+                               InteractionModifiers::PotShift,
+                               LongRangeVdW::Count),
+              vdwktLJCUT_COMBNONE);
+}
+
+TEST(KernelSetupTest, getVdwKernelTypeLjCutThrows)
+{
+    EXPECT_ANY_THROW(getVdwKernelType(Nbnxm::KernelType::NotSet,
+                                      LJCombinationRule::Count,
+                                      VanDerWaalsType::Cut,
+                                      InteractionModifiers::PotShift,
+                                      LongRangeVdW::Count));
+}
+
+TEST(KernelSetupTest, getVdwKernelTypeLjCutForceSwitch)
+{
+    EXPECT_EQ(getVdwKernelType(Nbnxm::KernelType::NotSet,
+                               LJCombinationRule::None,
+                               VanDerWaalsType::Cut,
+                               InteractionModifiers::ForceSwitch,
+                               LongRangeVdW::Count),
+              vdwktLJFORCESWITCH);
+}
+
+TEST(KernelSetupTest, getVdwKernelTypePmeGeom)
+{
+    EXPECT_EQ(getVdwKernelType(Nbnxm::KernelType::Cpu4x4_PlainC,
+                               LJCombinationRule::None,
+                               VanDerWaalsType::Pme,
+                               InteractionModifiers::Count,
+                               LongRangeVdW::Geom),
+              vdwktLJEWALDCOMBGEOM);
+}
+
+TEST(KernelSetupTest, getVdwKernelTypePmeNone)
+{
+    EXPECT_EQ(getVdwKernelType(Nbnxm::KernelType::Cpu4x4_PlainC,
+                               LJCombinationRule::None,
+                               VanDerWaalsType::Pme,
+                               InteractionModifiers::Count,
+                               LongRangeVdW::Count),
+              vdwktLJEWALDCOMBLB);
+}
+
+TEST(KernelSetupTest, getVdwKernelTypeLjCutPotSwitch)
+{
+    EXPECT_EQ(getVdwKernelType(Nbnxm::KernelType::NotSet,
+                               LJCombinationRule::None,
+                               VanDerWaalsType::Cut,
+                               InteractionModifiers::PotSwitch,
+                               LongRangeVdW::Count),
+              vdwktLJPOTSWITCH);
+}
+
+TEST(KernelSetupTest, getVdwKernelTypeAllCountThrows)
+{
+    // Count cannot be used for VanDerWaalsType or InteractionModifiers because of calls to
+    // enumValueToString(), which require a valid choice to have been made.
+    EXPECT_ANY_THROW(getVdwKernelType(Nbnxm::KernelType::NotSet,
+                                      LJCombinationRule::Count,
+                                      VanDerWaalsType::Cut,
+                                      InteractionModifiers::None,
+                                      LongRangeVdW::Count));
+}
+
 } // namespace
 } // namespace test
 } // namespace gmx