Basic support for 2D neighborhood search for analysis
[alexxy/gromacs.git] / src / gromacs / selection / tests / nbsearch.cpp
index 998c80371bf6e3e7f89b0f551e0ea7beab346ba9..f6c162ba2c7b5e8918a6160ff860910fc6d24129 100644 (file)
  * \author Teemu Murtola <teemu.murtola@gmail.com>
  * \ingroup module_selection
  */
+#include "gromacs/selection/nbsearch.h"
+
 #include <gtest/gtest.h>
 
 #include <cmath>
 
+#include <algorithm>
 #include <limits>
-#include <set>
 #include <vector>
 
 #include "gromacs/math/vec.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/random/random.h"
-#include "gromacs/selection/nbsearch.h"
 #include "gromacs/utility/smalloc.h"
 
 #include "testutils/testasserts.h"
@@ -70,6 +71,23 @@ namespace
 class NeighborhoodSearchTestData
 {
     public:
+        struct RefPair
+        {
+            RefPair(int refIndex, real distance)
+                : refIndex(refIndex), distance(distance), bFound(false)
+            {
+            }
+
+            bool operator<(const RefPair &other) const
+            {
+                return refIndex < other.refIndex;
+            }
+
+            int                 refIndex;
+            real                distance;
+            bool                bFound;
+        };
+
         struct TestPosition
         {
             TestPosition() : refMinDist(0.0), refNearestPoint(-1)
@@ -82,11 +100,12 @@ class NeighborhoodSearchTestData
                 copy_rvec(x, this->x);
             }
 
-            rvec                x;
-            real                refMinDist;
-            int                 refNearestPoint;
-            std::set<int>       refPairs;
+            rvec                 x;
+            real                 refMinDist;
+            int                  refNearestPoint;
+            std::vector<RefPair> refPairs;
         };
+
         typedef std::vector<TestPosition> TestPositionList;
 
         NeighborhoodSearchTestData(int seed, real cutoff);
@@ -123,7 +142,26 @@ class NeighborhoodSearchTestData
         void generateRandomPosition(rvec x);
         void generateRandomRefPositions(int count);
         void generateRandomTestPositions(int count);
-        void computeReferences(t_pbc *pbc);
+        void computeReferences(t_pbc *pbc)
+        {
+            computeReferencesInternal(pbc, false);
+        }
+        void computeReferencesXY(t_pbc *pbc)
+        {
+            computeReferencesInternal(pbc, true);
+        }
+
+        bool containsPair(int testIndex, const RefPair &pair) const
+        {
+            const std::vector<RefPair>          &refPairs = testPositions_[testIndex].refPairs;
+            std::vector<RefPair>::const_iterator foundRefPair
+                = std::lower_bound(refPairs.begin(), refPairs.end(), pair);
+            if (foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex)
+            {
+                return false;
+            }
+            return true;
+        }
 
         gmx_rng_t                        rng_;
         real                             cutoff_;
@@ -134,9 +172,14 @@ class NeighborhoodSearchTestData
         TestPositionList                 testPositions_;
 
     private:
+        void computeReferencesInternal(t_pbc *pbc, bool bXY);
+
         mutable rvec                    *testPos_;
 };
 
+//! Shorthand for a collection of reference pairs.
+typedef std::vector<NeighborhoodSearchTestData::RefPair> RefPairList;
+
 NeighborhoodSearchTestData::NeighborhoodSearchTestData(int seed, real cutoff)
     : rng_(NULL), cutoff_(cutoff), refPosCount_(0), refPos_(NULL), testPos_(NULL)
 {
@@ -190,7 +233,7 @@ void NeighborhoodSearchTestData::generateRandomTestPositions(int count)
     }
 }
 
-void NeighborhoodSearchTestData::computeReferences(t_pbc *pbc)
+void NeighborhoodSearchTestData::computeReferencesInternal(t_pbc *pbc, bool bXY)
 {
     real cutoff = cutoff_;
     if (cutoff <= 0)
@@ -214,7 +257,11 @@ void NeighborhoodSearchTestData::computeReferences(t_pbc *pbc)
             {
                 rvec_sub(i->x, refPos_[j], dx);
             }
-            const real dist = norm(dx);
+            // TODO: This may not work intuitively for 2D with the third box
+            // vector not parallel to the Z axis, but neither does the actual
+            // neighborhood search.
+            const real dist =
+                !bXY ? norm(dx) : sqrt(sqr(dx[XX]) + sqr(dx[YY]));
             if (dist < i->refMinDist)
             {
                 i->refMinDist      = dist;
@@ -222,7 +269,10 @@ void NeighborhoodSearchTestData::computeReferences(t_pbc *pbc)
             }
             if (dist <= cutoff)
             {
-                i->refPairs.insert(j);
+                RefPair pair(j, dist);
+                GMX_RELEASE_ASSERT(i->refPairs.empty() || i->refPairs.back() < pair,
+                                   "Reference pairs should be generated in sorted order");
+                i->refPairs.push_back(pair);
             }
         }
     }
@@ -285,6 +335,8 @@ void NeighborhoodSearchTest::testNearestPoint(
         {
             EXPECT_EQ(i->refNearestPoint, pair.refIndex());
             EXPECT_EQ(0, pair.testIndex());
+            EXPECT_REAL_EQ_TOL(i->refMinDist, sqrt(pair.distance2()),
+                               gmx::test::ulpTolerance(64));
         }
         else
         {
@@ -298,26 +350,53 @@ void NeighborhoodSearchTest::testPairSearch(
         const NeighborhoodSearchTestData &data)
 {
     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
+    // TODO: Test also searching all the test positions in a single search;
+    // currently the implementation just contains this loop, though, but in
+    // the future that may trigger a different code path.
     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
     {
-        std::set<int> checkSet                         = i->refPairs;
-        gmx::AnalysisNeighborhoodPairSearch pairSearch =
-            search->startPairSearch(i->x);
+        RefPairList                         refPairs = i->refPairs;
+        gmx::AnalysisNeighborhoodPairSearch pairSearch
+            search->startPairSearch(i->x);
         gmx::AnalysisNeighborhoodPair       pair;
         while (pairSearch.findNextPair(&pair))
         {
             EXPECT_EQ(0, pair.testIndex());
-            if (checkSet.erase(pair.refIndex()) == 0)
+            NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(),
+                                                           sqrt(pair.distance2()));
+            RefPairList::iterator               foundRefPair
+                = std::lower_bound(refPairs.begin(), refPairs.end(), searchPair);
+            if (foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex())
             {
-                // TODO: Check whether the same pair was returned more than
-                // once and give a better error message if so.
                 ADD_FAILURE()
                 << "Expected: Position " << pair.refIndex()
                 << " is within cutoff.\n"
                 << "  Actual: It is not.";
             }
+            else if (foundRefPair->bFound)
+            {
+                ADD_FAILURE()
+                << "Expected: Position " << pair.refIndex()
+                << " is returned only once.\n"
+                << "  Actual: It is returned multiple times.";
+            }
+            else
+            {
+                foundRefPair->bFound = true;
+                EXPECT_REAL_EQ_TOL(foundRefPair->distance, searchPair.distance,
+                                   gmx::test::ulpTolerance(64));
+            }
+        }
+        for (RefPairList::const_iterator j = refPairs.begin(); j != refPairs.end(); ++j)
+        {
+            if (!j->bFound)
+            {
+                ADD_FAILURE()
+                << "Expected: All pairs within cutoff will be returned.\n"
+                << "  Actual: Position " << j->refIndex << " is not found.";
+                break;
+            }
         }
-        EXPECT_TRUE(checkSet.empty()) << "Some positions were not returned by the pair search.";
     }
 }
 
@@ -375,6 +454,32 @@ class RandomBoxFullPBCData
         NeighborhoodSearchTestData data_;
 };
 
+class RandomBoxXYFullPBCData
+{
+    public:
+        static const NeighborhoodSearchTestData &get()
+        {
+            static RandomBoxXYFullPBCData singleton;
+            return singleton.data_;
+        }
+
+        RandomBoxXYFullPBCData() : data_(54321, 1.0)
+        {
+            data_.box_[XX][XX] = 10.0;
+            data_.box_[YY][YY] = 5.0;
+            data_.box_[ZZ][ZZ] = 7.0;
+            // TODO: Consider whether manually picking some positions would give better
+            // test coverage.
+            data_.generateRandomRefPositions(1000);
+            data_.generateRandomTestPositions(100);
+            set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
+            data_.computeReferencesXY(&data_.pbc_);
+        }
+
+    private:
+        NeighborhoodSearchTestData data_;
+};
+
 class RandomTriclinicFullPBCData
 {
     public:
@@ -496,6 +601,24 @@ TEST_F(NeighborhoodSearchTest, GridSearch2DPBC)
     testPairSearch(&search, data);
 }
 
+TEST_F(NeighborhoodSearchTest, GridSearchXYBox)
+{
+    const NeighborhoodSearchTestData &data = RandomBoxXYFullPBCData::get();
+
+    nb_.setCutoff(data.cutoff_);
+    nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
+    nb_.setXYMode(true);
+    gmx::AnalysisNeighborhoodSearch search =
+        nb_.initSearch(&data.pbc_, data.refPositions());
+    // Currently, grid searching not supported with XY.
+    //ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
+
+    testIsWithin(&search, data);
+    testMinimumDistance(&search, data);
+    testNearestPoint(&search, data);
+    testPairSearch(&search, data);
+}
+
 TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
 {
     const NeighborhoodSearchTestData &data = TrivialTestData::get();
@@ -516,11 +639,17 @@ TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
     gmx::AnalysisNeighborhoodPair pair;
     pairSearch1.findNextPair(&pair);
     EXPECT_EQ(0, pair.testIndex());
-    EXPECT_TRUE(data.testPositions_[0].refPairs.count(pair.refIndex()) == 1);
+    {
+        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+        EXPECT_TRUE(data.containsPair(0, searchPair));
+    }
 
     pairSearch2.findNextPair(&pair);
     EXPECT_EQ(1, pair.testIndex());
-    EXPECT_TRUE(data.testPositions_[1].refPairs.count(pair.refIndex()) == 1);
+    {
+        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+        EXPECT_TRUE(data.containsPair(1, searchPair));
+    }
 }
 
 TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
@@ -545,7 +674,8 @@ TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
             ++currentIndex;
         }
         EXPECT_EQ(currentIndex, pair.testIndex());
-        EXPECT_TRUE(data.testPositions_[currentIndex].refPairs.count(pair.refIndex()) == 1);
+        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+        EXPECT_TRUE(data.containsPair(currentIndex, searchPair));
         pairSearch.skipRemainingPairsForTestPosition();
         ++currentIndex;
     }