Lift atom type lookup out of inner loops
[alexxy/gromacs.git] / src / gromacs / gmxpreprocess / toppush.cpp
index 9aae7a3926705bd525f38d24f7f0ff7cf57255b5..c73e821844e9b36a2a72a97ad6f73ae9c355b9b7 100644 (file)
@@ -45,6 +45,7 @@
 #include <cstring>
 
 #include <algorithm>
+#include <array>
 #include <string>
 
 #include "gromacs/fileio/warninp.h"
@@ -1722,20 +1723,17 @@ static bool default_cmap_params(gmx::ArrayRef<InteractionsOfType> bondtype,
 /* Returns the number of exact atom type matches, i.e. non wild-card matches,
  * returns -1 when there are no matches at all.
  */
-static int natom_match(const InteractionOfType&      pi,
-                       int                           type_i,
-                       int                           type_j,
-                       int                           type_k,
-                       int                           type_l,
-                       const PreprocessingAtomTypes* atypes)
+static int findNumberOfDihedralAtomMatches(const InteractionOfType&       bondType,
+                                           const gmx::ArrayRef<const int> atomTypes)
 {
-    if ((pi.ai() == -1 || atypes->bondAtomTypeFromAtomType(type_i) == pi.ai())
-        && (pi.aj() == -1 || atypes->bondAtomTypeFromAtomType(type_j) == pi.aj())
-        && (pi.ak() == -1 || atypes->bondAtomTypeFromAtomType(type_k) == pi.ak())
-        && (pi.al() == -1 || atypes->bondAtomTypeFromAtomType(type_l) == pi.al()))
+    GMX_RELEASE_ASSERT(atomTypes.size() == 4, "Dihedrals have 4 atom types");
+    if ((bondType.ai() == -1 || atomTypes[0] == bondType.ai())
+        && (bondType.aj() == -1 || atomTypes[1] == bondType.aj())
+        && (bondType.ak() == -1 || atomTypes[2] == bondType.ak())
+        && (bondType.al() == -1 || atomTypes[3] == bondType.al()))
     {
-        return (pi.ai() == -1 ? 0 : 1) + (pi.aj() == -1 ? 0 : 1) + (pi.ak() == -1 ? 0 : 1)
-               + (pi.al() == -1 ? 0 : 1);
+        return (bondType.ai() == -1 ? 0 : 1) + (bondType.aj() == -1 ? 0 : 1)
+               + (bondType.ak() == -1 ? 0 : 1) + (bondType.al() == -1 ? 0 : 1);
     }
     else
     {
@@ -1743,77 +1741,14 @@ static int natom_match(const InteractionOfType&      pi,
     }
 }
 
-static int findNumberOfDihedralAtomMatches(const InteractionOfType& currentParamFromParameterArray,
-                                           const InteractionOfType& parameterToAdd,
-                                           const t_atoms*           at,
-                                           const PreprocessingAtomTypes* atypes,
-                                           bool                          bB)
-{
-    if (bB)
-    {
-        return natom_match(currentParamFromParameterArray,
-                           at->atom[parameterToAdd.ai()].typeB,
-                           at->atom[parameterToAdd.aj()].typeB,
-                           at->atom[parameterToAdd.ak()].typeB,
-                           at->atom[parameterToAdd.al()].typeB,
-                           atypes);
-    }
-    else
-    {
-        return natom_match(currentParamFromParameterArray,
-                           at->atom[parameterToAdd.ai()].type,
-                           at->atom[parameterToAdd.aj()].type,
-                           at->atom[parameterToAdd.ak()].type,
-                           at->atom[parameterToAdd.al()].type,
-                           atypes);
-    }
-}
-
-static bool findIfAllParameterAtomsMatch(gmx::ArrayRef<const int>      atomsFromParameterArray,
-                                         gmx::ArrayRef<const int>      atomsFromCurrentParameter,
-                                         const t_atoms*                at,
-                                         const PreprocessingAtomTypes* atypes,
-                                         bool                          bB)
-{
-    if (atomsFromParameterArray.size() != atomsFromCurrentParameter.size())
-    {
-        return false;
-    }
-    else if (bB)
-    {
-        for (gmx::index i = 0; i < atomsFromCurrentParameter.ssize(); i++)
-        {
-            if (atypes->bondAtomTypeFromAtomType(at->atom[atomsFromCurrentParameter[i]].typeB)
-                != atomsFromParameterArray[i])
-            {
-                return false;
-            }
-        }
-        return true;
-    }
-    else
-    {
-        for (gmx::index i = 0; i < atomsFromCurrentParameter.ssize(); i++)
-        {
-            if (atypes->bondAtomTypeFromAtomType(at->atom[atomsFromCurrentParameter[i]].type)
-                != atomsFromParameterArray[i])
-            {
-                return false;
-            }
-        }
-        return true;
-    }
-}
-
-static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ftype,
-                                                                          gmx::ArrayRef<InteractionsOfType> bt,
-                                                                          t_atoms* at,
-                                                                          PreprocessingAtomTypes* atypes,
-                                                                          const InteractionOfType& p,
-                                                                          bool bB,
-                                                                          int* nparam_def)
+static std::vector<InteractionOfType>::iterator
+defaultInteractionsOfType(int                               ftype,
+                          gmx::ArrayRef<InteractionsOfType> bondType,
+                          const gmx::ArrayRef<const int>    atomTypes,
+                          int*                              nparam_def)
 {
     int nparam_found = 0;
+
     if (ftype == F_PDIHS || ftype == F_RBDIHS || ftype == F_IDIHS || ftype == F_PIDIHS)
     {
         int nmatch_max = -1;
@@ -1821,24 +1756,23 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft
         /* For dihedrals we allow wildcards. We choose the first type
          * that has the most real matches, i.e. non-wildcard matches.
          */
-        auto prevPos = bt[ftype].interactionTypes.end();
-        auto pos     = bt[ftype].interactionTypes.begin();
-        while (pos != bt[ftype].interactionTypes.end() && nmatch_max < 4)
-        {
-            pos = std::find_if(bt[ftype].interactionTypes.begin(),
-                               bt[ftype].interactionTypes.end(),
-                               [&p, &at, &atypes, &bB, &nmatch_max](const auto& param) {
-                                   return (findNumberOfDihedralAtomMatches(param, p, at, atypes, bB)
-                                           > nmatch_max);
+        auto prevPos = bondType[ftype].interactionTypes.end();
+        auto pos     = bondType[ftype].interactionTypes.begin();
+        while (pos != bondType[ftype].interactionTypes.end() && nmatch_max < 4)
+        {
+            pos = std::find_if(bondType[ftype].interactionTypes.begin(),
+                               bondType[ftype].interactionTypes.end(),
+                               [&atomTypes, &nmatch_max](const auto& bondType) {
+                                   return (findNumberOfDihedralAtomMatches(bondType, atomTypes) > nmatch_max);
                                });
-            if (pos != bt[ftype].interactionTypes.end())
+            if (pos != bondType[ftype].interactionTypes.end())
             {
                 prevPos    = pos;
-                nmatch_max = findNumberOfDihedralAtomMatches(*pos, p, at, atypes, bB);
+                nmatch_max = findNumberOfDihedralAtomMatches(*pos, atomTypes);
             }
         }
 
-        if (prevPos != bt[ftype].interactionTypes.end())
+        if (prevPos != bondType[ftype].interactionTypes.end())
         {
             nparam_found++;
 
@@ -1854,7 +1788,7 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft
             };
             /* Continue from current iterator position */
             auto       nextPos = prevPos;
-            const auto endIter = bt[ftype].interactionTypes.end();
+            const auto endIter = bondType[ftype].interactionTypes.end();
             safeAdvance(nextPos, 2, endIter);
             for (; nextPos < endIter && bSame; safeAdvance(nextPos, 2, endIter))
             {
@@ -1872,14 +1806,13 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft
     }
     else /* Not a dihedral */
     {
-        gmx::ArrayRef<const int> atomParam = p.atoms();
-        auto                     found     = std::find_if(bt[ftype].interactionTypes.begin(),
-                                  bt[ftype].interactionTypes.end(),
-                                  [&atomParam, &at, &atypes, &bB](const auto& param) {
-                                      return findIfAllParameterAtomsMatch(
-                                              param.atoms(), atomParam, at, atypes, bB);
-                                  });
-        if (found != bt[ftype].interactionTypes.end())
+        auto found = std::find_if(
+                bondType[ftype].interactionTypes.begin(),
+                bondType[ftype].interactionTypes.end(),
+                [&atomTypes](const auto& param) {
+                    return std::equal(param.atoms().begin(), param.atoms().end(), atomTypes.begin());
+                });
+        if (found != bondType[ftype].interactionTypes.end())
         {
             nparam_found = 1;
         }
@@ -1912,10 +1845,10 @@ void push_bond(Directive                         d,
     int         nral, nral_fmt, nread, ftype;
     char        format[STRLEN];
     /* One force parameter more, so we can check if we read too many */
-    double cc[MAXFORCEPARAM + 1];
-    int    aa[MAXATOMLIST + 1];
-    bool   bFoundA = FALSE, bFoundB = FALSE, bDef, bSwapParity = FALSE;
-    int    nparam_defA, nparam_defB;
+    double                           cc[MAXFORCEPARAM + 1];
+    std::array<int, MAXATOMLIST + 1> aa;
+    bool                             bFoundA = FALSE, bFoundB = FALSE, bDef, bSwapParity = FALSE;
+    int                              nparam_defA, nparam_defB;
 
     nparam_defA = nparam_defB = 0;
 
@@ -1979,7 +1912,7 @@ void push_bond(Directive                         d,
     }
 
 
-    /* Check for double atoms and atoms out of bounds */
+    /* Check for double atoms and atoms out of bounds, then convert to 0-based indexing */
     for (int i = 0; (i < nral); i++)
     {
         if (aa[i] < 1 || aa[i] > at->nr)
@@ -2018,21 +1951,33 @@ void push_bond(Directive                         d,
                 }
             }
         }
+
+        // Convert to 0-based indexing
+        --aa[i];
     }
 
+    // These are the atom indices for this interaction
+    gmx::ArrayRef<int> atomIndices(aa.begin(), aa.begin() + nral);
+
+    // Look up the A-state atom types for this interaction
+    std::vector<int> atomTypes(atomIndices.size());
+    std::transform(atomIndices.begin(), atomIndices.end(), atomTypes.begin(), [at, atypes](const int atomIndex) {
+        return atypes->bondAtomTypeFromAtomType(at->atom[atomIndex].type).value();
+    });
+    // Look up the B-state atom types for this interaction
+    std::vector<int> atomTypesB(atomIndices.size());
+    std::transform(atomIndices.begin(), atomIndices.end(), atomTypesB.begin(), [at, atypes](const int atomIndex) {
+        return atypes->bondAtomTypeFromAtomType(at->atom[atomIndex].typeB).value();
+    });
+
     /* default force parameters  */
-    std::vector<int> atoms;
-    for (int j = 0; (j < nral); j++)
-    {
-        atoms.emplace_back(aa[j] - 1);
-    }
     /* need to have an empty but initialized param array for some reason */
     std::array<real, MAXFORCEPARAM> forceParam = { 0.0 };
 
     /* Get force params for normal and free energy perturbation
      * studies, as determined by types!
      */
-    InteractionOfType param(atoms, forceParam, "");
+    InteractionOfType param(atomIndices, forceParam, "");
 
     std::vector<InteractionOfType>::iterator foundAParameter = bondtype[ftype].interactionTypes.end();
     std::vector<InteractionOfType>::iterator foundBParameter = bondtype[ftype].interactionTypes.end();
@@ -2044,8 +1989,7 @@ void push_bond(Directive                         d,
         }
         else
         {
-            foundAParameter =
-                    defaultInteractionsOfType(ftype, bondtype, at, atypes, param, FALSE, &nparam_defA);
+            foundAParameter = defaultInteractionsOfType(ftype, bondtype, atomTypes, &nparam_defA);
             if (foundAParameter != bondtype[ftype].interactionTypes.end())
             {
                 /* Copy the A-state and B-state default parameters. */
@@ -2066,8 +2010,7 @@ void push_bond(Directive                         d,
         }
         else
         {
-            foundBParameter =
-                    defaultInteractionsOfType(ftype, bondtype, at, atypes, param, TRUE, &nparam_defB);
+            foundBParameter = defaultInteractionsOfType(ftype, bondtype, atomTypesB, &nparam_defB);
             if (foundBParameter != bondtype[ftype].interactionTypes.end())
             {
                 /* Copy only the B-state default parameters */