Refactor analysis nbsearch tests
authorTeemu Murtola <teemu.murtola@gmail.com>
Tue, 22 Aug 2017 19:09:49 +0000 (22:09 +0300)
committerDavid van der Spoel <davidvanderspoel@gmail.com>
Mon, 4 Sep 2017 08:38:37 +0000 (10:38 +0200)
Remove assumptions from the tests about the order in which pairs are
returned.  Prepares the tests for an all-pairs search from a single set
of positions, where the order is not as predictable.  And the test code
is actually easier to understand this way.

Change-Id: Id33eaff1c4c7f94a26099c6d4e34e7e008c1afa4

src/gromacs/selection/tests/nbsearch.cpp

index af445d88c35ed74fae099a3a039c36fe1ae9e216..3a5f8751544b1790fb8de54c2becd4378c7f9f94 100644 (file)
@@ -52,6 +52,7 @@
 
 #include <algorithm>
 #include <limits>
+#include <map>
 #include <numeric>
 #include <vector>
 
@@ -265,39 +266,38 @@ void NeighborhoodSearchTestData::computeReferencesInternal(t_pbc *pbc, bool bXY)
     {
         cutoff = std::numeric_limits<real>::max();
     }
-    TestPositionList::iterator i;
-    for (i = testPositions_.begin(); i != testPositions_.end(); ++i)
+    for (TestPosition &testPos : testPositions_)
     {
-        i->refMinDist      = cutoff;
-        i->refNearestPoint = -1;
-        i->refPairs.clear();
+        testPos.refMinDist      = cutoff;
+        testPos.refNearestPoint = -1;
+        testPos.refPairs.clear();
         for (int j = 0; j < refPosCount_; ++j)
         {
             rvec dx;
             if (pbc != nullptr)
             {
-                pbc_dx(pbc, i->x, refPos_[j], dx);
+                pbc_dx(pbc, testPos.x, refPos_[j], dx);
             }
             else
             {
-                rvec_sub(i->x, refPos_[j], dx);
+                rvec_sub(testPos.x, refPos_[j], dx);
             }
             // TODO: This may not work intuitively for 2D with the third box
             // vector not parallel to the Z axis, but neither does the actual
             // neighborhood search.
             const real dist =
                 !bXY ? norm(dx) : std::hypot(dx[XX], dx[YY]);
-            if (dist < i->refMinDist)
+            if (dist < testPos.refMinDist)
             {
-                i->refMinDist      = dist;
-                i->refNearestPoint = j;
+                testPos.refMinDist      = dist;
+                testPos.refNearestPoint = j;
             }
             if (dist <= cutoff)
             {
                 RefPair pair(j, dist);
-                GMX_RELEASE_ASSERT(i->refPairs.empty() || i->refPairs.back() < pair,
+                GMX_RELEASE_ASSERT(testPos.refPairs.empty() || testPos.refPairs.back() < pair,
                                    "Reference pairs should be generated in sorted order");
-                i->refPairs.push_back(pair);
+                testPos.refPairs.push_back(pair);
             }
         }
     }
@@ -544,89 +544,83 @@ void NeighborhoodSearchTest::testPairSearchFull(
         const gmx::ConstArrayRef<int>            &refIndices,
         const gmx::ConstArrayRef<int>            &testIndices)
 {
+    std::map<int, RefPairList> refPairs;
     // TODO: Some parts of this code do not work properly if pos does not
     // initially contain all the test positions.
-    std::set<int> remainingTestPositions;
     gmx::AnalysisNeighborhoodPositions  posCopy(pos);
     if (testIndices.empty())
     {
         for (size_t i = 0; i < data.testPositions_.size(); ++i)
         {
-            remainingTestPositions.insert(i);
+            refPairs[i] = data.testPositions_[i].refPairs;
         }
     }
     else
     {
-        remainingTestPositions.insert(testIndices.begin(), testIndices.end());
+        for (int index : testIndices)
+        {
+            refPairs[index] = data.testPositions_[index].refPairs;
+        }
         posCopy.indexed(testIndices);
     }
-
-    gmx::AnalysisNeighborhoodPairSearch pairSearch
-        = search->startPairSearch(posCopy);
-    gmx::AnalysisNeighborhoodPair       pair;
-    // There is an ordering assumption here that all pairs for a test position
-    // are returned consencutively; with the current optimizations in the
-    // search code, this is reasoable, as the set of grid cell pairs searched
-    // depends on the test position.
-    RefPairList refPairs;
-    int         prevTestPos = -1;
-    while (pairSearch.findNextPair(&pair))
+    if (excls != nullptr)
     {
-        const int testIndex =
-            (testIndices.empty() ? pair.testIndex() : testIndices[pair.testIndex()]);
-        const int refIndex =
-            (refIndices.empty() ? pair.refIndex() : refIndices[pair.refIndex()]);
-        if (testIndex != prevTestPos)
+        for (auto &entry : refPairs)
         {
-            if (prevTestPos != -1)
-            {
-                checkAllPairsFound(refPairs, data.refPos_, prevTestPos,
-                                   data.testPositions_[prevTestPos].x);
-            }
-            if (remainingTestPositions.count(testIndex) == 0)
-            {
-                ADD_FAILURE()
-                << "Pairs for test position " << testIndex
-                << " are returned more than once.";
-            }
-            remainingTestPositions.erase(testIndex);
-            refPairs = data.testPositions_[testIndex].refPairs;
-            if (excls != nullptr)
+            const int testIndex = entry.first;
+            ExclusionsHelper::markExcludedPairs(&entry.second, testIndex, excls);
+        }
+    }
+    if (!refIndices.empty())
+    {
+        for (auto &entry : refPairs)
+        {
+            for (auto &refPair : entry.second)
             {
-                ExclusionsHelper::markExcludedPairs(&refPairs, testIndex, excls);
+                refPair.bIndexed = false;
             }
-            if (!refIndices.empty())
+            for (int index : refIndices)
             {
-                RefPairList::iterator refPair;
-                for (refPair = refPairs.begin(); refPair != refPairs.end(); ++refPair)
+                NeighborhoodSearchTestData::RefPair searchPair(index, 0.0);
+                auto refPair = std::lower_bound(entry.second.begin(), entry.second.end(), searchPair);
+                if (refPair != entry.second.end() && refPair->refIndex == index)
                 {
-                    refPair->bIndexed = false;
+                    refPair->bIndexed = true;
                 }
-                for (size_t i = 0; i < refIndices.size(); ++i)
-                {
-                    NeighborhoodSearchTestData::RefPair searchPair(refIndices[i], 0.0);
-                    refPair = std::lower_bound(refPairs.begin(), refPairs.end(), searchPair);
-                    if (refPair != refPairs.end() && refPair->refIndex == refIndices[i])
-                    {
-                        refPair->bIndexed = true;
-                    }
-                }
-                for (refPair = refPairs.begin(); refPair != refPairs.end(); ++refPair)
+            }
+            for (auto &refPair : entry.second)
+            {
+                if (!refPair.bIndexed)
                 {
-                    if (!refPair->bIndexed)
-                    {
-                        refPair->bFound = true;
-                    }
+                    refPair.bFound = true;
                 }
             }
-            prevTestPos = testIndex;
         }
+    }
 
+    gmx::AnalysisNeighborhoodPairSearch pairSearch
+        = search->startPairSearch(posCopy);
+    gmx::AnalysisNeighborhoodPair       pair;
+    while (pairSearch.findNextPair(&pair))
+    {
+        const int testIndex =
+            (testIndices.empty() ? pair.testIndex() : testIndices[pair.testIndex()]);
+        const int refIndex =
+            (refIndices.empty() ? pair.refIndex() : refIndices[pair.refIndex()]);
+
+        if (refPairs.count(testIndex) == 0)
+        {
+            ADD_FAILURE()
+            << "Expected: No pairs are returned for test position " << testIndex << ".\n"
+            << "  Actual: Pair with ref " << refIndex << " is returned.";
+            continue;
+        }
         NeighborhoodSearchTestData::RefPair searchPair(refIndex,
                                                        std::sqrt(pair.distance2()));
-        RefPairList::iterator               foundRefPair
-            = std::lower_bound(refPairs.begin(), refPairs.end(), searchPair);
-        if (foundRefPair == refPairs.end() || foundRefPair->refIndex != refIndex)
+        const auto foundRefPair
+            = std::lower_bound(refPairs[testIndex].begin(), refPairs[testIndex].end(),
+                               searchPair);
+        if (foundRefPair == refPairs[testIndex].end() || foundRefPair->refIndex != refIndex)
         {
             ADD_FAILURE()
             << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
@@ -664,38 +658,11 @@ void NeighborhoodSearchTest::testPairSearchFull(
         }
     }
 
-    checkAllPairsFound(refPairs, data.refPos_, prevTestPos,
-                       data.testPositions_[prevTestPos].x);
-
-    std::set<int> refPositions(refIndices.begin(), refIndices.end());
-
-    for (std::set<int>::const_iterator i = remainingTestPositions.begin();
-         i != remainingTestPositions.end(); ++i)
+    for (auto &entry : refPairs)
     {
-        // Account for the case where the i particle is listed in the testIndex,
-        // but none of its ref neighbours were listed in the refIndex.
-        if (!refIndices.empty())
-        {
-            RefPairList::const_iterator refPair;
-            bool foundAnyRefInIndex = false;
-
-            for (refPair = data.testPositions_[*i].refPairs.begin();
-                 refPair != data.testPositions_[*i].refPairs.end() && !foundAnyRefInIndex; ++refPair)
-            {
-                foundAnyRefInIndex = (refPositions.count(refPair->refIndex) > 0);
-            }
-            if (!foundAnyRefInIndex)
-            {
-                continue;
-            }
-        }
-        if (!data.testPositions_[*i].refPairs.empty())
-        {
-            ADD_FAILURE()
-            << "Expected: Pairs would be returned for test position " << *i << ".\n"
-            << "  Actual: None were returned.";
-            break;
-        }
+        const int testIndex = entry.first;
+        checkAllPairsFound(entry.second, data.refPos_, testIndex,
+                           data.testPositions_[testIndex].x);
     }
 }