Merge release-5-0 into master
[alexxy/gromacs.git] / src / gromacs / selection / tests / nbsearch.cpp
index 177913fad1212ed2c34808fe16dd744b31b516ef..1d89772d2a197883fb94864b6cec07c1390821da 100644 (file)
  * \author Teemu Murtola <teemu.murtola@gmail.com>
  * \ingroup module_selection
  */
+#include "gmxpre.h"
+
+#include "gromacs/selection/nbsearch.h"
+
 #include <gtest/gtest.h>
 
 #include <cmath>
 
+#include <algorithm>
 #include <limits>
-#include <set>
+#include <numeric>
 #include <vector>
 
-#include "gromacs/legacyheaders/pbc.h"
-#include "gromacs/legacyheaders/vec.h"
-
-#include "gromacs/selection/nbsearch.h"
+#include "gromacs/math/vec.h"
+#include "gromacs/pbcutil/pbc.h"
 #include "gromacs/random/random.h"
+#include "gromacs/topology/block.h"
 #include "gromacs/utility/smalloc.h"
 
 #include "testutils/testasserts.h"
@@ -71,6 +75,29 @@ namespace
 class NeighborhoodSearchTestData
 {
     public:
+        struct RefPair
+        {
+            RefPair(int refIndex, real distance)
+                : refIndex(refIndex), distance(distance), bFound(false),
+                  bExcluded(false)
+            {
+            }
+
+            bool operator<(const RefPair &other) const
+            {
+                return refIndex < other.refIndex;
+            }
+
+            int                 refIndex;
+            real                distance;
+            // The variables below are state variables that are only used
+            // during the actual testing after creating a copy of the reference
+            // pair list, not as part of the reference data.
+            // Simpler to have just a single structure for both purposes.
+            bool                bFound;
+            bool                bExcluded;
+        };
+
         struct TestPosition
         {
             explicit TestPosition(const rvec x)
@@ -79,11 +106,12 @@ class NeighborhoodSearchTestData
                 copy_rvec(x, this->x);
             }
 
-            rvec                x;
-            real                refMinDist;
-            int                 refNearestPoint;
-            std::set<int>       refPairs;
+            rvec                 x;
+            real                 refMinDist;
+            int                  refNearestPoint;
+            std::vector<RefPair> refPairs;
         };
+
         typedef std::vector<TestPosition> TestPositionList;
 
         NeighborhoodSearchTestData(int seed, real cutoff);
@@ -120,7 +148,26 @@ class NeighborhoodSearchTestData
         void generateRandomPosition(rvec x);
         void generateRandomRefPositions(int count);
         void generateRandomTestPositions(int count);
-        void computeReferences(t_pbc *pbc);
+        void computeReferences(t_pbc *pbc)
+        {
+            computeReferencesInternal(pbc, false);
+        }
+        void computeReferencesXY(t_pbc *pbc)
+        {
+            computeReferencesInternal(pbc, true);
+        }
+
+        bool containsPair(int testIndex, const RefPair &pair) const
+        {
+            const std::vector<RefPair>          &refPairs = testPositions_[testIndex].refPairs;
+            std::vector<RefPair>::const_iterator foundRefPair
+                = std::lower_bound(refPairs.begin(), refPairs.end(), pair);
+            if (foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex)
+            {
+                return false;
+            }
+            return true;
+        }
 
         gmx_rng_t                        rng_;
         real                             cutoff_;
@@ -131,9 +178,14 @@ class NeighborhoodSearchTestData
         TestPositionList                 testPositions_;
 
     private:
+        void computeReferencesInternal(t_pbc *pbc, bool bXY);
+
         mutable rvec                    *testPos_;
 };
 
+//! Shorthand for a collection of reference pairs.
+typedef std::vector<NeighborhoodSearchTestData::RefPair> RefPairList;
+
 NeighborhoodSearchTestData::NeighborhoodSearchTestData(int seed, real cutoff)
     : rng_(NULL), cutoff_(cutoff), refPosCount_(0), refPos_(NULL), testPos_(NULL)
 {
@@ -187,7 +239,7 @@ void NeighborhoodSearchTestData::generateRandomTestPositions(int count)
     }
 }
 
-void NeighborhoodSearchTestData::computeReferences(t_pbc *pbc)
+void NeighborhoodSearchTestData::computeReferencesInternal(t_pbc *pbc, bool bXY)
 {
     real cutoff = cutoff_;
     if (cutoff <= 0)
@@ -211,7 +263,11 @@ void NeighborhoodSearchTestData::computeReferences(t_pbc *pbc)
             {
                 rvec_sub(i->x, refPos_[j], dx);
             }
-            const real dist = norm(dx);
+            // TODO: This may not work intuitively for 2D with the third box
+            // vector not parallel to the Z axis, but neither does the actual
+            // neighborhood search.
+            const real dist =
+                !bXY ? norm(dx) : sqrt(sqr(dx[XX]) + sqr(dx[YY]));
             if (dist < i->refMinDist)
             {
                 i->refMinDist      = dist;
@@ -219,12 +275,114 @@ void NeighborhoodSearchTestData::computeReferences(t_pbc *pbc)
             }
             if (dist <= cutoff)
             {
-                i->refPairs.insert(j);
+                RefPair pair(j, dist);
+                GMX_RELEASE_ASSERT(i->refPairs.empty() || i->refPairs.back() < pair,
+                                   "Reference pairs should be generated in sorted order");
+                i->refPairs.push_back(pair);
             }
         }
     }
 }
 
+/********************************************************************
+ * ExclusionsHelper
+ */
+
+class ExclusionsHelper
+{
+    public:
+        static void markExcludedPairs(RefPairList *refPairs, int testIndex,
+                                      const t_blocka *excls);
+
+        ExclusionsHelper(int refPosCount, int testPosCount);
+
+        void generateExclusions();
+
+        const t_blocka *exclusions() const { return &excls_; }
+
+        gmx::ConstArrayRef<int> refPosIds() const
+        {
+            return gmx::constArrayRefFromVector<int>(exclusionIds_.begin(),
+                                                     exclusionIds_.begin() + refPosCount_);
+        }
+        gmx::ConstArrayRef<int> testPosIds() const
+        {
+            return gmx::constArrayRefFromVector<int>(exclusionIds_.begin(),
+                                                     exclusionIds_.begin() + testPosCount_);
+        }
+
+    private:
+        int              refPosCount_;
+        int              testPosCount_;
+        std::vector<int> exclusionIds_;
+        std::vector<int> exclsIndex_;
+        std::vector<int> exclsAtoms_;
+        t_blocka         excls_;
+};
+
+// static
+void ExclusionsHelper::markExcludedPairs(RefPairList *refPairs, int testIndex,
+                                         const t_blocka *excls)
+{
+    int count = 0;
+    for (int i = excls->index[testIndex]; i < excls->index[testIndex + 1]; ++i)
+    {
+        const int                           excludedIndex = excls->a[i];
+        NeighborhoodSearchTestData::RefPair searchPair(excludedIndex, 0.0);
+        RefPairList::iterator               excludedRefPair
+            = std::lower_bound(refPairs->begin(), refPairs->end(), searchPair);
+        if (excludedRefPair != refPairs->end()
+            && excludedRefPair->refIndex == excludedIndex)
+        {
+            excludedRefPair->bFound    = true;
+            excludedRefPair->bExcluded = true;
+            ++count;
+        }
+    }
+}
+
+ExclusionsHelper::ExclusionsHelper(int refPosCount, int testPosCount)
+    : refPosCount_(refPosCount), testPosCount_(testPosCount)
+{
+    // Generate an array of 0, 1, 2, ...
+    // TODO: Make the tests work also with non-trivial exclusion IDs,
+    // and test that.
+    exclusionIds_.resize(std::max(refPosCount, testPosCount), 1);
+    exclusionIds_[0] = 0;
+    std::partial_sum(exclusionIds_.begin(), exclusionIds_.end(),
+                     exclusionIds_.begin());
+
+    excls_.nr           = 0;
+    excls_.index        = NULL;
+    excls_.nra          = 0;
+    excls_.a            = NULL;
+    excls_.nalloc_index = 0;
+    excls_.nalloc_a     = 0;
+}
+
+void ExclusionsHelper::generateExclusions()
+{
+    // TODO: Consider a better set of test data, where the density of the
+    // particles would be higher, or where the exclusions would not be random,
+    // to make a higher percentage of the exclusions to actually be within the
+    // cutoff.
+    exclsIndex_.reserve(testPosCount_ + 1);
+    exclsAtoms_.reserve(testPosCount_ * 20);
+    exclsIndex_.push_back(0);
+    for (int i = 0; i < testPosCount_; ++i)
+    {
+        for (int j = 0; j < 20; ++j)
+        {
+            exclsAtoms_.push_back(i + j*3);
+        }
+        exclsIndex_.push_back(exclsAtoms_.size());
+    }
+    excls_.nr    = exclsIndex_.size();
+    excls_.index = &exclsIndex_[0];
+    excls_.nra   = exclsAtoms_.size();
+    excls_.a     = &exclsAtoms_[0];
+}
+
 /********************************************************************
  * NeighborhoodSearchTest
  */
@@ -240,6 +398,10 @@ class NeighborhoodSearchTest : public ::testing::Test
                               const NeighborhoodSearchTestData &data);
         void testPairSearch(gmx::AnalysisNeighborhoodSearch  *search,
                             const NeighborhoodSearchTestData &data);
+        void testPairSearchFull(gmx::AnalysisNeighborhoodSearch          *search,
+                                const NeighborhoodSearchTestData         &data,
+                                const gmx::AnalysisNeighborhoodPositions &pos,
+                                const t_blocka                           *excls);
 
         gmx::AnalysisNeighborhood        nb_;
 };
@@ -282,6 +444,8 @@ void NeighborhoodSearchTest::testNearestPoint(
         {
             EXPECT_EQ(i->refNearestPoint, pair.refIndex());
             EXPECT_EQ(0, pair.testIndex());
+            EXPECT_REAL_EQ_TOL(i->refMinDist, sqrt(pair.distance2()),
+                               gmx::test::ulpTolerance(64));
         }
         else
         {
@@ -290,31 +454,119 @@ void NeighborhoodSearchTest::testNearestPoint(
     }
 }
 
+/*! \brief
+ * Helper function to check that all expected pairs were found.
+ */
+static void checkAllPairsFound(const RefPairList &refPairs)
+{
+    // This could be elegantly expressed with Google Mock matchers, but that
+    // has a significant effect on the runtime of the tests...
+    for (RefPairList::const_iterator i = refPairs.begin(); i != refPairs.end(); ++i)
+    {
+        if (!i->bFound)
+        {
+            ADD_FAILURE()
+            << "Some pairs within the cutoff were not found.";
+            break;
+        }
+    }
+}
+
 void NeighborhoodSearchTest::testPairSearch(
         gmx::AnalysisNeighborhoodSearch  *search,
         const NeighborhoodSearchTestData &data)
 {
-    NeighborhoodSearchTestData::TestPositionList::const_iterator i;
-    for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
+    testPairSearchFull(search, data, data.testPositions(), NULL);
+}
+
+void NeighborhoodSearchTest::testPairSearchFull(
+        gmx::AnalysisNeighborhoodSearch          *search,
+        const NeighborhoodSearchTestData         &data,
+        const gmx::AnalysisNeighborhoodPositions &pos,
+        const t_blocka                           *excls)
+{
+    // TODO: Some parts of this code do not work properly if pos does not
+    // contain all the test positions.
+    std::set<int> remainingTestPositions;
+    for (size_t i = 0; i < data.testPositions_.size(); ++i)
     {
-        std::set<int> checkSet                         = i->refPairs;
-        gmx::AnalysisNeighborhoodPairSearch pairSearch =
-            search->startPairSearch(i->x);
-        gmx::AnalysisNeighborhoodPair       pair;
-        while (pairSearch.findNextPair(&pair))
+        remainingTestPositions.insert(i);
+    }
+    gmx::AnalysisNeighborhoodPairSearch pairSearch
+        = search->startPairSearch(pos);
+    gmx::AnalysisNeighborhoodPair       pair;
+    // TODO: There is an ordering assumption here that may break in the future:
+    // all pairs for a test position are assumed to be returned consencutively.
+    RefPairList refPairs;
+    int         prevTestPos = -1;
+    while (pairSearch.findNextPair(&pair))
+    {
+        if (pair.testIndex() != prevTestPos)
         {
-            EXPECT_EQ(0, pair.testIndex());
-            if (checkSet.erase(pair.refIndex()) == 0)
+            if (prevTestPos != -1)
+            {
+                checkAllPairsFound(refPairs);
+            }
+            const int testIndex = pair.testIndex();
+            if (remainingTestPositions.count(testIndex) == 0)
             {
-                // TODO: Check whether the same pair was returned more than
-                // once and give a better error message if so.
                 ADD_FAILURE()
-                << "Expected: Position " << pair.refIndex()
-                << " is within cutoff.\n"
-                << "  Actual: It is not.";
+                << "Pairs for test position " << testIndex
+                << " are returned more than once.";
+            }
+            remainingTestPositions.erase(testIndex);
+            refPairs = data.testPositions_[testIndex].refPairs;
+            if (excls != NULL)
+            {
+                ExclusionsHelper::markExcludedPairs(&refPairs, testIndex, excls);
             }
+            prevTestPos = testIndex;
+        }
+
+        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(),
+                                                       sqrt(pair.distance2()));
+        RefPairList::iterator               foundRefPair
+            = std::lower_bound(refPairs.begin(), refPairs.end(), searchPair);
+        if (foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex())
+        {
+            ADD_FAILURE()
+            << "Expected: Pair (ref: " << pair.refIndex() << ", test: "
+            << pair.testIndex() << ") is not within the cutoff.\n"
+            << "  Actual: It is returned.";
+        }
+        else if (foundRefPair->bExcluded)
+        {
+            ADD_FAILURE()
+            << "Expected: Pair (ref: " << pair.refIndex() << ", test: "
+            << pair.testIndex() << ") is excluded from the search.\n"
+            << "  Actual: It is returned.";
+        }
+        else if (foundRefPair->bFound)
+        {
+            ADD_FAILURE()
+            << "Expected: Pair (ref: " << pair.refIndex() << ", test: "
+            << pair.testIndex() << ") is returned only once.\n"
+            << "  Actual: It is returned multiple times.";
+        }
+        else
+        {
+            foundRefPair->bFound = true;
+            EXPECT_REAL_EQ_TOL(foundRefPair->distance, searchPair.distance,
+                               gmx::test::ulpTolerance(64))
+            << "Distance computed by the neighborhood search does not match.";
+        }
+    }
+    checkAllPairsFound(refPairs);
+    for (std::set<int>::const_iterator i = remainingTestPositions.begin();
+         i != remainingTestPositions.end(); ++i)
+    {
+        if (!data.testPositions_[*i].refPairs.empty())
+        {
+            ADD_FAILURE()
+            << "Expected: Pairs would be returned for test position " << *i << ".\n"
+            << "  Actual: None were returned.";
+            break;
         }
-        EXPECT_TRUE(checkSet.empty()) << "Some positions were not returned by the pair search.";
     }
 }
 
@@ -372,6 +624,32 @@ class RandomBoxFullPBCData
         NeighborhoodSearchTestData data_;
 };
 
+class RandomBoxXYFullPBCData
+{
+    public:
+        static const NeighborhoodSearchTestData &get()
+        {
+            static RandomBoxXYFullPBCData singleton;
+            return singleton.data_;
+        }
+
+        RandomBoxXYFullPBCData() : data_(54321, 1.0)
+        {
+            data_.box_[XX][XX] = 10.0;
+            data_.box_[YY][YY] = 5.0;
+            data_.box_[ZZ][ZZ] = 7.0;
+            // TODO: Consider whether manually picking some positions would give better
+            // test coverage.
+            data_.generateRandomRefPositions(1000);
+            data_.generateRandomTestPositions(100);
+            set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
+            data_.computeReferencesXY(&data_.pbc_);
+        }
+
+    private:
+        NeighborhoodSearchTestData data_;
+};
+
 class RandomTriclinicFullPBCData
 {
     public:
@@ -493,6 +771,24 @@ TEST_F(NeighborhoodSearchTest, GridSearch2DPBC)
     testPairSearch(&search, data);
 }
 
+TEST_F(NeighborhoodSearchTest, GridSearchXYBox)
+{
+    const NeighborhoodSearchTestData &data = RandomBoxXYFullPBCData::get();
+
+    nb_.setCutoff(data.cutoff_);
+    nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
+    nb_.setXYMode(true);
+    gmx::AnalysisNeighborhoodSearch search =
+        nb_.initSearch(&data.pbc_, data.refPositions());
+    // Currently, grid searching not supported with XY.
+    //ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
+
+    testIsWithin(&search, data);
+    testMinimumDistance(&search, data);
+    testNearestPoint(&search, data);
+    testPairSearch(&search, data);
+}
+
 TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
 {
     const NeighborhoodSearchTestData &data = TrivialTestData::get();
@@ -511,13 +807,21 @@ TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
     testPairSearch(&search2, data);
 
     gmx::AnalysisNeighborhoodPair pair;
-    pairSearch1.findNextPair(&pair);
+    ASSERT_TRUE(pairSearch1.findNextPair(&pair))
+    << "Test data did not contain any pairs for position 0 (problem in the test).";
     EXPECT_EQ(0, pair.testIndex());
-    EXPECT_TRUE(data.testPositions_[0].refPairs.count(pair.refIndex()) == 1);
+    {
+        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+        EXPECT_TRUE(data.containsPair(0, searchPair));
+    }
 
-    pairSearch2.findNextPair(&pair);
+    ASSERT_TRUE(pairSearch2.findNextPair(&pair))
+    << "Test data did not contain any pairs for position 1 (problem in the test).";
     EXPECT_EQ(1, pair.testIndex());
-    EXPECT_TRUE(data.testPositions_[1].refPairs.count(pair.refIndex()) == 1);
+    {
+        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+        EXPECT_TRUE(data.containsPair(1, searchPair));
+    }
 }
 
 TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
@@ -542,10 +846,51 @@ TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
             ++currentIndex;
         }
         EXPECT_EQ(currentIndex, pair.testIndex());
-        EXPECT_TRUE(data.testPositions_[currentIndex].refPairs.count(pair.refIndex()) == 1);
+        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+        EXPECT_TRUE(data.containsPair(currentIndex, searchPair));
         pairSearch.skipRemainingPairsForTestPosition();
         ++currentIndex;
     }
 }
 
+TEST_F(NeighborhoodSearchTest, SimpleSearchExclusions)
+{
+    const NeighborhoodSearchTestData &data = RandomBoxFullPBCData::get();
+
+    ExclusionsHelper                  helper(data.refPosCount_, data.testPositions_.size());
+    helper.generateExclusions();
+
+    nb_.setCutoff(data.cutoff_);
+    nb_.setTopologyExclusions(helper.exclusions());
+    nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
+    gmx::AnalysisNeighborhoodSearch search =
+        nb_.initSearch(&data.pbc_,
+                       data.refPositions().exclusionIds(helper.refPosIds()));
+    ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
+
+    testPairSearchFull(&search, data,
+                       data.testPositions().exclusionIds(helper.testPosIds()),
+                       helper.exclusions());
+}
+
+TEST_F(NeighborhoodSearchTest, GridSearchExclusions)
+{
+    const NeighborhoodSearchTestData &data = RandomBoxFullPBCData::get();
+
+    ExclusionsHelper                  helper(data.refPosCount_, data.testPositions_.size());
+    helper.generateExclusions();
+
+    nb_.setCutoff(data.cutoff_);
+    nb_.setTopologyExclusions(helper.exclusions());
+    nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
+    gmx::AnalysisNeighborhoodSearch search =
+        nb_.initSearch(&data.pbc_,
+                       data.refPositions().exclusionIds(helper.refPosIds()));
+    ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
+
+    testPairSearchFull(&search, data,
+                       data.testPositions().exclusionIds(helper.testPosIds()),
+                       helper.exclusions());
+}
+
 } // namespace