Lift atom type lookup out of inner loops
authorMark Abraham <mark.j.abraham@gmail.com>
Mon, 26 Jul 2021 08:31:26 +0000 (10:31 +0200)
committerArtem Zhmurov <zhmurov@gmail.com>
Thu, 5 Aug 2021 14:10:24 +0000 (14:10 +0000)
Grompp loops over molecule types, looking up force parameters for each
interaction from the associated bond types for the system (e.g. from
the force field). The atom types for that interaction have to be
looked up from the atoms for the molecule type that contains it, but
this should be done only once, before considering each bond type as a
possible match. The lookups for both A- and B-state parameters are now
lifted out of the loops over bond types, simplifying the logic and
significantly improving performance.

Once that is done, one custom function could be replaced by
std::equal.

Improved some variable naming

Apply 1 suggestion(s) to 1 file(s)

src/gromacs/gmxpreprocess/toppush.cpp

index 9aae7a3926705bd525f38d24f7f0ff7cf57255b5..c73e821844e9b36a2a72a97ad6f73ae9c355b9b7 100644 (file)
@@ -45,6 +45,7 @@
 #include <cstring>
 
 #include <algorithm>
+#include <array>
 #include <string>
 
 #include "gromacs/fileio/warninp.h"
@@ -1722,20 +1723,17 @@ static bool default_cmap_params(gmx::ArrayRef<InteractionsOfType> bondtype,
 /* Returns the number of exact atom type matches, i.e. non wild-card matches,
  * returns -1 when there are no matches at all.
  */
-static int natom_match(const InteractionOfType&      pi,
-                       int                           type_i,
-                       int                           type_j,
-                       int                           type_k,
-                       int                           type_l,
-                       const PreprocessingAtomTypes* atypes)
+static int findNumberOfDihedralAtomMatches(const InteractionOfType&       bondType,
+                                           const gmx::ArrayRef<const int> atomTypes)
 {
-    if ((pi.ai() == -1 || atypes->bondAtomTypeFromAtomType(type_i) == pi.ai())
-        && (pi.aj() == -1 || atypes->bondAtomTypeFromAtomType(type_j) == pi.aj())
-        && (pi.ak() == -1 || atypes->bondAtomTypeFromAtomType(type_k) == pi.ak())
-        && (pi.al() == -1 || atypes->bondAtomTypeFromAtomType(type_l) == pi.al()))
+    GMX_RELEASE_ASSERT(atomTypes.size() == 4, "Dihedrals have 4 atom types");
+    if ((bondType.ai() == -1 || atomTypes[0] == bondType.ai())
+        && (bondType.aj() == -1 || atomTypes[1] == bondType.aj())
+        && (bondType.ak() == -1 || atomTypes[2] == bondType.ak())
+        && (bondType.al() == -1 || atomTypes[3] == bondType.al()))
     {
-        return (pi.ai() == -1 ? 0 : 1) + (pi.aj() == -1 ? 0 : 1) + (pi.ak() == -1 ? 0 : 1)
-               + (pi.al() == -1 ? 0 : 1);
+        return (bondType.ai() == -1 ? 0 : 1) + (bondType.aj() == -1 ? 0 : 1)
+               + (bondType.ak() == -1 ? 0 : 1) + (bondType.al() == -1 ? 0 : 1);
     }
     else
     {
@@ -1743,77 +1741,14 @@ static int natom_match(const InteractionOfType&      pi,
     }
 }
 
-static int findNumberOfDihedralAtomMatches(const InteractionOfType& currentParamFromParameterArray,
-                                           const InteractionOfType& parameterToAdd,
-                                           const t_atoms*           at,
-                                           const PreprocessingAtomTypes* atypes,
-                                           bool                          bB)
-{
-    if (bB)
-    {
-        return natom_match(currentParamFromParameterArray,
-                           at->atom[parameterToAdd.ai()].typeB,
-                           at->atom[parameterToAdd.aj()].typeB,
-                           at->atom[parameterToAdd.ak()].typeB,
-                           at->atom[parameterToAdd.al()].typeB,
-                           atypes);
-    }
-    else
-    {
-        return natom_match(currentParamFromParameterArray,
-                           at->atom[parameterToAdd.ai()].type,
-                           at->atom[parameterToAdd.aj()].type,
-                           at->atom[parameterToAdd.ak()].type,
-                           at->atom[parameterToAdd.al()].type,
-                           atypes);
-    }
-}
-
-static bool findIfAllParameterAtomsMatch(gmx::ArrayRef<const int>      atomsFromParameterArray,
-                                         gmx::ArrayRef<const int>      atomsFromCurrentParameter,
-                                         const t_atoms*                at,
-                                         const PreprocessingAtomTypes* atypes,
-                                         bool                          bB)
-{
-    if (atomsFromParameterArray.size() != atomsFromCurrentParameter.size())
-    {
-        return false;
-    }
-    else if (bB)
-    {
-        for (gmx::index i = 0; i < atomsFromCurrentParameter.ssize(); i++)
-        {
-            if (atypes->bondAtomTypeFromAtomType(at->atom[atomsFromCurrentParameter[i]].typeB)
-                != atomsFromParameterArray[i])
-            {
-                return false;
-            }
-        }
-        return true;
-    }
-    else
-    {
-        for (gmx::index i = 0; i < atomsFromCurrentParameter.ssize(); i++)
-        {
-            if (atypes->bondAtomTypeFromAtomType(at->atom[atomsFromCurrentParameter[i]].type)
-                != atomsFromParameterArray[i])
-            {
-                return false;
-            }
-        }
-        return true;
-    }
-}
-
-static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ftype,
-                                                                          gmx::ArrayRef<InteractionsOfType> bt,
-                                                                          t_atoms* at,
-                                                                          PreprocessingAtomTypes* atypes,
-                                                                          const InteractionOfType& p,
-                                                                          bool bB,
-                                                                          int* nparam_def)
+static std::vector<InteractionOfType>::iterator
+defaultInteractionsOfType(int                               ftype,
+                          gmx::ArrayRef<InteractionsOfType> bondType,
+                          const gmx::ArrayRef<const int>    atomTypes,
+                          int*                              nparam_def)
 {
     int nparam_found = 0;
+
     if (ftype == F_PDIHS || ftype == F_RBDIHS || ftype == F_IDIHS || ftype == F_PIDIHS)
     {
         int nmatch_max = -1;
@@ -1821,24 +1756,23 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft
         /* For dihedrals we allow wildcards. We choose the first type
          * that has the most real matches, i.e. non-wildcard matches.
          */
-        auto prevPos = bt[ftype].interactionTypes.end();
-        auto pos     = bt[ftype].interactionTypes.begin();
-        while (pos != bt[ftype].interactionTypes.end() && nmatch_max < 4)
-        {
-            pos = std::find_if(bt[ftype].interactionTypes.begin(),
-                               bt[ftype].interactionTypes.end(),
-                               [&p, &at, &atypes, &bB, &nmatch_max](const auto& param) {
-                                   return (findNumberOfDihedralAtomMatches(param, p, at, atypes, bB)
-                                           > nmatch_max);
+        auto prevPos = bondType[ftype].interactionTypes.end();
+        auto pos     = bondType[ftype].interactionTypes.begin();
+        while (pos != bondType[ftype].interactionTypes.end() && nmatch_max < 4)
+        {
+            pos = std::find_if(bondType[ftype].interactionTypes.begin(),
+                               bondType[ftype].interactionTypes.end(),
+                               [&atomTypes, &nmatch_max](const auto& bondType) {
+                                   return (findNumberOfDihedralAtomMatches(bondType, atomTypes) > nmatch_max);
                                });
-            if (pos != bt[ftype].interactionTypes.end())
+            if (pos != bondType[ftype].interactionTypes.end())
             {
                 prevPos    = pos;
-                nmatch_max = findNumberOfDihedralAtomMatches(*pos, p, at, atypes, bB);
+                nmatch_max = findNumberOfDihedralAtomMatches(*pos, atomTypes);
             }
         }
 
-        if (prevPos != bt[ftype].interactionTypes.end())
+        if (prevPos != bondType[ftype].interactionTypes.end())
         {
             nparam_found++;
 
@@ -1854,7 +1788,7 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft
             };
             /* Continue from current iterator position */
             auto       nextPos = prevPos;
-            const auto endIter = bt[ftype].interactionTypes.end();
+            const auto endIter = bondType[ftype].interactionTypes.end();
             safeAdvance(nextPos, 2, endIter);
             for (; nextPos < endIter && bSame; safeAdvance(nextPos, 2, endIter))
             {
@@ -1872,14 +1806,13 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft
     }
     else /* Not a dihedral */
     {
-        gmx::ArrayRef<const int> atomParam = p.atoms();
-        auto                     found     = std::find_if(bt[ftype].interactionTypes.begin(),
-                                  bt[ftype].interactionTypes.end(),
-                                  [&atomParam, &at, &atypes, &bB](const auto& param) {
-                                      return findIfAllParameterAtomsMatch(
-                                              param.atoms(), atomParam, at, atypes, bB);
-                                  });
-        if (found != bt[ftype].interactionTypes.end())
+        auto found = std::find_if(
+                bondType[ftype].interactionTypes.begin(),
+                bondType[ftype].interactionTypes.end(),
+                [&atomTypes](const auto& param) {
+                    return std::equal(param.atoms().begin(), param.atoms().end(), atomTypes.begin());
+                });
+        if (found != bondType[ftype].interactionTypes.end())
         {
             nparam_found = 1;
         }
@@ -1912,10 +1845,10 @@ void push_bond(Directive                         d,
     int         nral, nral_fmt, nread, ftype;
     char        format[STRLEN];
     /* One force parameter more, so we can check if we read too many */
-    double cc[MAXFORCEPARAM + 1];
-    int    aa[MAXATOMLIST + 1];
-    bool   bFoundA = FALSE, bFoundB = FALSE, bDef, bSwapParity = FALSE;
-    int    nparam_defA, nparam_defB;
+    double                           cc[MAXFORCEPARAM + 1];
+    std::array<int, MAXATOMLIST + 1> aa;
+    bool                             bFoundA = FALSE, bFoundB = FALSE, bDef, bSwapParity = FALSE;
+    int                              nparam_defA, nparam_defB;
 
     nparam_defA = nparam_defB = 0;
 
@@ -1979,7 +1912,7 @@ void push_bond(Directive                         d,
     }
 
 
-    /* Check for double atoms and atoms out of bounds */
+    /* Check for double atoms and atoms out of bounds, then convert to 0-based indexing */
     for (int i = 0; (i < nral); i++)
     {
         if (aa[i] < 1 || aa[i] > at->nr)
@@ -2018,21 +1951,33 @@ void push_bond(Directive                         d,
                 }
             }
         }
+
+        // Convert to 0-based indexing
+        --aa[i];
     }
 
+    // These are the atom indices for this interaction
+    gmx::ArrayRef<int> atomIndices(aa.begin(), aa.begin() + nral);
+
+    // Look up the A-state atom types for this interaction
+    std::vector<int> atomTypes(atomIndices.size());
+    std::transform(atomIndices.begin(), atomIndices.end(), atomTypes.begin(), [at, atypes](const int atomIndex) {
+        return atypes->bondAtomTypeFromAtomType(at->atom[atomIndex].type).value();
+    });
+    // Look up the B-state atom types for this interaction
+    std::vector<int> atomTypesB(atomIndices.size());
+    std::transform(atomIndices.begin(), atomIndices.end(), atomTypesB.begin(), [at, atypes](const int atomIndex) {
+        return atypes->bondAtomTypeFromAtomType(at->atom[atomIndex].typeB).value();
+    });
+
     /* default force parameters  */
-    std::vector<int> atoms;
-    for (int j = 0; (j < nral); j++)
-    {
-        atoms.emplace_back(aa[j] - 1);
-    }
     /* need to have an empty but initialized param array for some reason */
     std::array<real, MAXFORCEPARAM> forceParam = { 0.0 };
 
     /* Get force params for normal and free energy perturbation
      * studies, as determined by types!
      */
-    InteractionOfType param(atoms, forceParam, "");
+    InteractionOfType param(atomIndices, forceParam, "");
 
     std::vector<InteractionOfType>::iterator foundAParameter = bondtype[ftype].interactionTypes.end();
     std::vector<InteractionOfType>::iterator foundBParameter = bondtype[ftype].interactionTypes.end();
@@ -2044,8 +1989,7 @@ void push_bond(Directive                         d,
         }
         else
         {
-            foundAParameter =
-                    defaultInteractionsOfType(ftype, bondtype, at, atypes, param, FALSE, &nparam_defA);
+            foundAParameter = defaultInteractionsOfType(ftype, bondtype, atomTypes, &nparam_defA);
             if (foundAParameter != bondtype[ftype].interactionTypes.end())
             {
                 /* Copy the A-state and B-state default parameters. */
@@ -2066,8 +2010,7 @@ void push_bond(Directive                         d,
         }
         else
         {
-            foundBParameter =
-                    defaultInteractionsOfType(ftype, bondtype, at, atypes, param, TRUE, &nparam_defB);
+            foundBParameter = defaultInteractionsOfType(ftype, bondtype, atomTypesB, &nparam_defB);
             if (foundBParameter != bondtype[ftype].interactionTypes.end())
             {
                 /* Copy only the B-state default parameters */