* \author Teemu Murtola <teemu.murtola@gmail.com>
* \ingroup module_selection
*/
-#include <gtest/gtest.h>
+#include "gmxpre.h"
+
+#include "gromacs/selection/nbsearch.h"
#include <cmath>
+#include <algorithm>
#include <limits>
-#include <set>
+#include <numeric>
#include <vector>
-#include "gromacs/random/random.h"
-#include "gromacs/legacyheaders/pbc.h"
-#include "gromacs/legacyheaders/smalloc.h"
-#include "gromacs/legacyheaders/vec.h"
+#include <gtest/gtest.h>
-#include "gromacs/selection/nbsearch.h"
+#include "gromacs/math/vec.h"
+#include "gromacs/pbcutil/pbc.h"
+#include "gromacs/random/random.h"
+#include "gromacs/topology/block.h"
+#include "gromacs/utility/smalloc.h"
+#include "gromacs/utility/stringutil.h"
#include "testutils/testasserts.h"
class NeighborhoodSearchTestData
{
public:
- struct TestPosition
+ struct RefPair
{
- TestPosition() : refMinDist(0.0), refNearestPoint(-1)
+ RefPair(int refIndex, real distance)
+ : refIndex(refIndex), distance(distance), bFound(false),
+ bExcluded(false)
+ {
+ }
+
+ bool operator<(const RefPair &other) const
{
- clear_rvec(x);
+ return refIndex < other.refIndex;
}
+
+ int refIndex;
+ real distance;
+ // The variables below are state variables that are only used
+ // during the actual testing after creating a copy of the reference
+ // pair list, not as part of the reference data.
+ // Simpler to have just a single structure for both purposes.
+ bool bFound;
+ bool bExcluded;
+ };
+
+ struct TestPosition
+ {
explicit TestPosition(const rvec x)
: refMinDist(0.0), refNearestPoint(-1)
{
copy_rvec(x, this->x);
}
- rvec x;
- real refMinDist;
- int refNearestPoint;
- std::set<int> refPairs;
+ rvec x;
+ real refMinDist;
+ int refNearestPoint;
+ std::vector<RefPair> refPairs;
};
+
typedef std::vector<TestPosition> TestPositionList;
NeighborhoodSearchTestData(int seed, real cutoff);
void generateRandomPosition(rvec x);
void generateRandomRefPositions(int count);
void generateRandomTestPositions(int count);
- void computeReferences(t_pbc *pbc);
+ void computeReferences(t_pbc *pbc)
+ {
+ computeReferencesInternal(pbc, false);
+ }
+ void computeReferencesXY(t_pbc *pbc)
+ {
+ computeReferencesInternal(pbc, true);
+ }
+
+ bool containsPair(int testIndex, const RefPair &pair) const
+ {
+ const std::vector<RefPair> &refPairs = testPositions_[testIndex].refPairs;
+ std::vector<RefPair>::const_iterator foundRefPair
+ = std::lower_bound(refPairs.begin(), refPairs.end(), pair);
+ if (foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex)
+ {
+ return false;
+ }
+ return true;
+ }
gmx_rng_t rng_;
real cutoff_;
TestPositionList testPositions_;
private:
+ void computeReferencesInternal(t_pbc *pbc, bool bXY);
+
mutable rvec *testPos_;
};
+//! Shorthand for a collection of reference pairs.
+typedef std::vector<NeighborhoodSearchTestData::RefPair> RefPairList;
+
NeighborhoodSearchTestData::NeighborhoodSearchTestData(int seed, real cutoff)
: rng_(NULL), cutoff_(cutoff), refPosCount_(0), refPos_(NULL), testPos_(NULL)
{
}
}
-void NeighborhoodSearchTestData::computeReferences(t_pbc *pbc)
+void NeighborhoodSearchTestData::computeReferencesInternal(t_pbc *pbc, bool bXY)
{
real cutoff = cutoff_;
if (cutoff <= 0)
{
rvec_sub(i->x, refPos_[j], dx);
}
- const real dist = norm(dx);
+ // TODO: This may not work intuitively for 2D with the third box
+ // vector not parallel to the Z axis, but neither does the actual
+ // neighborhood search.
+ const real dist =
+ !bXY ? norm(dx) : sqrt(sqr(dx[XX]) + sqr(dx[YY]));
if (dist < i->refMinDist)
{
i->refMinDist = dist;
}
if (dist <= cutoff)
{
- i->refPairs.insert(j);
+ RefPair pair(j, dist);
+ GMX_RELEASE_ASSERT(i->refPairs.empty() || i->refPairs.back() < pair,
+ "Reference pairs should be generated in sorted order");
+ i->refPairs.push_back(pair);
}
}
}
}
+/********************************************************************
+ * ExclusionsHelper
+ */
+
+class ExclusionsHelper
+{
+ public:
+ static void markExcludedPairs(RefPairList *refPairs, int testIndex,
+ const t_blocka *excls);
+
+ ExclusionsHelper(int refPosCount, int testPosCount);
+
+ void generateExclusions();
+
+ const t_blocka *exclusions() const { return &excls_; }
+
+ gmx::ConstArrayRef<int> refPosIds() const
+ {
+ return gmx::constArrayRefFromVector<int>(exclusionIds_.begin(),
+ exclusionIds_.begin() + refPosCount_);
+ }
+ gmx::ConstArrayRef<int> testPosIds() const
+ {
+ return gmx::constArrayRefFromVector<int>(exclusionIds_.begin(),
+ exclusionIds_.begin() + testPosCount_);
+ }
+
+ private:
+ int refPosCount_;
+ int testPosCount_;
+ std::vector<int> exclusionIds_;
+ std::vector<int> exclsIndex_;
+ std::vector<int> exclsAtoms_;
+ t_blocka excls_;
+};
+
+// static
+void ExclusionsHelper::markExcludedPairs(RefPairList *refPairs, int testIndex,
+ const t_blocka *excls)
+{
+ int count = 0;
+ for (int i = excls->index[testIndex]; i < excls->index[testIndex + 1]; ++i)
+ {
+ const int excludedIndex = excls->a[i];
+ NeighborhoodSearchTestData::RefPair searchPair(excludedIndex, 0.0);
+ RefPairList::iterator excludedRefPair
+ = std::lower_bound(refPairs->begin(), refPairs->end(), searchPair);
+ if (excludedRefPair != refPairs->end()
+ && excludedRefPair->refIndex == excludedIndex)
+ {
+ excludedRefPair->bFound = true;
+ excludedRefPair->bExcluded = true;
+ ++count;
+ }
+ }
+}
+
+ExclusionsHelper::ExclusionsHelper(int refPosCount, int testPosCount)
+ : refPosCount_(refPosCount), testPosCount_(testPosCount)
+{
+ // Generate an array of 0, 1, 2, ...
+ // TODO: Make the tests work also with non-trivial exclusion IDs,
+ // and test that.
+ exclusionIds_.resize(std::max(refPosCount, testPosCount), 1);
+ exclusionIds_[0] = 0;
+ std::partial_sum(exclusionIds_.begin(), exclusionIds_.end(),
+ exclusionIds_.begin());
+
+ excls_.nr = 0;
+ excls_.index = NULL;
+ excls_.nra = 0;
+ excls_.a = NULL;
+ excls_.nalloc_index = 0;
+ excls_.nalloc_a = 0;
+}
+
+void ExclusionsHelper::generateExclusions()
+{
+ // TODO: Consider a better set of test data, where the density of the
+ // particles would be higher, or where the exclusions would not be random,
+ // to make a higher percentage of the exclusions to actually be within the
+ // cutoff.
+ exclsIndex_.reserve(testPosCount_ + 1);
+ exclsAtoms_.reserve(testPosCount_ * 20);
+ exclsIndex_.push_back(0);
+ for (int i = 0; i < testPosCount_; ++i)
+ {
+ for (int j = 0; j < 20; ++j)
+ {
+ exclsAtoms_.push_back(i + j*3);
+ }
+ exclsIndex_.push_back(exclsAtoms_.size());
+ }
+ excls_.nr = exclsIndex_.size();
+ excls_.index = &exclsIndex_[0];
+ excls_.nra = exclsAtoms_.size();
+ excls_.a = &exclsAtoms_[0];
+}
+
/********************************************************************
* NeighborhoodSearchTest
*/
const NeighborhoodSearchTestData &data);
void testPairSearch(gmx::AnalysisNeighborhoodSearch *search,
const NeighborhoodSearchTestData &data);
+ void testPairSearchFull(gmx::AnalysisNeighborhoodSearch *search,
+ const NeighborhoodSearchTestData &data,
+ const gmx::AnalysisNeighborhoodPositions &pos,
+ const t_blocka *excls);
gmx::AnalysisNeighborhood nb_;
};
{
EXPECT_EQ(i->refNearestPoint, pair.refIndex());
EXPECT_EQ(0, pair.testIndex());
+ EXPECT_REAL_EQ_TOL(i->refMinDist, sqrt(pair.distance2()),
+ gmx::test::ulpTolerance(64));
}
else
{
}
}
+//! Helper function for formatting test failure messages.
+std::string formatVector(const rvec x)
+{
+ return gmx::formatString("[%.3f, %.3f, %.3f]", x[XX], x[YY], x[ZZ]);
+}
+
+/*! \brief
+ * Helper function to check that all expected pairs were found.
+ */
+void checkAllPairsFound(const RefPairList &refPairs, const rvec refPos[],
+ int testPosIndex, const rvec testPos)
+{
+ // This could be elegantly expressed with Google Mock matchers, but that
+ // has a significant effect on the runtime of the tests...
+ int count = 0;
+ RefPairList::const_iterator first;
+ for (RefPairList::const_iterator i = refPairs.begin(); i != refPairs.end(); ++i)
+ {
+ if (!i->bFound)
+ {
+ ++count;
+ first = i;
+ }
+ }
+ if (count > 0)
+ {
+ ADD_FAILURE()
+ << "Some pairs (" << count << "/" << refPairs.size() << ") "
+ << "within the cutoff were not found. First pair:\n"
+ << " Ref: " << first->refIndex << " at "
+ << formatVector(refPos[first->refIndex]) << "\n"
+ << "Test: " << testPosIndex << " at " << formatVector(testPos) << "\n"
+ << "Dist: " << first->distance;
+ }
+}
+
void NeighborhoodSearchTest::testPairSearch(
gmx::AnalysisNeighborhoodSearch *search,
const NeighborhoodSearchTestData &data)
{
- NeighborhoodSearchTestData::TestPositionList::const_iterator i;
- for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
+ testPairSearchFull(search, data, data.testPositions(), NULL);
+}
+
+void NeighborhoodSearchTest::testPairSearchFull(
+ gmx::AnalysisNeighborhoodSearch *search,
+ const NeighborhoodSearchTestData &data,
+ const gmx::AnalysisNeighborhoodPositions &pos,
+ const t_blocka *excls)
+{
+ // TODO: Some parts of this code do not work properly if pos does not
+ // contain all the test positions.
+ std::set<int> remainingTestPositions;
+ for (size_t i = 0; i < data.testPositions_.size(); ++i)
{
- std::set<int> checkSet = i->refPairs;
- gmx::AnalysisNeighborhoodPairSearch pairSearch =
- search->startPairSearch(i->x);
- gmx::AnalysisNeighborhoodPair pair;
- while (pairSearch.findNextPair(&pair))
+ remainingTestPositions.insert(i);
+ }
+ gmx::AnalysisNeighborhoodPairSearch pairSearch
+ = search->startPairSearch(pos);
+ gmx::AnalysisNeighborhoodPair pair;
+ // TODO: There is an ordering assumption here that may break in the future:
+ // all pairs for a test position are assumed to be returned consencutively.
+ RefPairList refPairs;
+ int prevTestPos = -1;
+ while (pairSearch.findNextPair(&pair))
+ {
+ if (pair.testIndex() != prevTestPos)
{
- EXPECT_EQ(0, pair.testIndex());
- if (checkSet.erase(pair.refIndex()) == 0)
+ if (prevTestPos != -1)
+ {
+ checkAllPairsFound(refPairs, data.refPos_, prevTestPos,
+ data.testPositions_[prevTestPos].x);
+ }
+ const int testIndex = pair.testIndex();
+ if (remainingTestPositions.count(testIndex) == 0)
{
- // TODO: Check whether the same pair was returned more than
- // once and give a better error message if so.
ADD_FAILURE()
- << "Expected: Position " << pair.refIndex()
- << " is within cutoff.\n"
- << " Actual: It is not.";
+ << "Pairs for test position " << testIndex
+ << " are returned more than once.";
+ }
+ remainingTestPositions.erase(testIndex);
+ refPairs = data.testPositions_[testIndex].refPairs;
+ if (excls != NULL)
+ {
+ ExclusionsHelper::markExcludedPairs(&refPairs, testIndex, excls);
}
+ prevTestPos = testIndex;
+ }
+
+ NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(),
+ sqrt(pair.distance2()));
+ RefPairList::iterator foundRefPair
+ = std::lower_bound(refPairs.begin(), refPairs.end(), searchPair);
+ if (foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex())
+ {
+ ADD_FAILURE()
+ << "Expected: Pair (ref: " << pair.refIndex() << ", test: "
+ << pair.testIndex() << ") is not within the cutoff.\n"
+ << " Actual: It is returned.";
+ }
+ else if (foundRefPair->bExcluded)
+ {
+ ADD_FAILURE()
+ << "Expected: Pair (ref: " << pair.refIndex() << ", test: "
+ << pair.testIndex() << ") is excluded from the search.\n"
+ << " Actual: It is returned.";
+ }
+ else if (foundRefPair->bFound)
+ {
+ ADD_FAILURE()
+ << "Expected: Pair (ref: " << pair.refIndex() << ", test: "
+ << pair.testIndex() << ") is returned only once.\n"
+ << " Actual: It is returned multiple times.";
+ }
+ else
+ {
+ foundRefPair->bFound = true;
+ EXPECT_REAL_EQ_TOL(foundRefPair->distance, searchPair.distance,
+ gmx::test::ulpTolerance(64))
+ << "Distance computed by the neighborhood search does not match.";
+ }
+ }
+ checkAllPairsFound(refPairs, data.refPos_, prevTestPos,
+ data.testPositions_[prevTestPos].x);
+ for (std::set<int>::const_iterator i = remainingTestPositions.begin();
+ i != remainingTestPositions.end(); ++i)
+ {
+ if (!data.testPositions_[*i].refPairs.empty())
+ {
+ ADD_FAILURE()
+ << "Expected: Pairs would be returned for test position " << *i << ".\n"
+ << " Actual: None were returned.";
+ break;
}
- EXPECT_TRUE(checkSet.empty()) << "Some positions were not returned by the pair search.";
}
}
NeighborhoodSearchTestData data_;
};
+class RandomBoxXYFullPBCData
+{
+ public:
+ static const NeighborhoodSearchTestData &get()
+ {
+ static RandomBoxXYFullPBCData singleton;
+ return singleton.data_;
+ }
+
+ RandomBoxXYFullPBCData() : data_(54321, 1.0)
+ {
+ data_.box_[XX][XX] = 10.0;
+ data_.box_[YY][YY] = 5.0;
+ data_.box_[ZZ][ZZ] = 7.0;
+ // TODO: Consider whether manually picking some positions would give better
+ // test coverage.
+ data_.generateRandomRefPositions(1000);
+ data_.generateRandomTestPositions(100);
+ set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
+ data_.computeReferencesXY(&data_.pbc_);
+ }
+
+ private:
+ NeighborhoodSearchTestData data_;
+};
+
class RandomTriclinicFullPBCData
{
public:
testPairSearch(&search, data);
}
+TEST_F(NeighborhoodSearchTest, GridSearchXYBox)
+{
+ const NeighborhoodSearchTestData &data = RandomBoxXYFullPBCData::get();
+
+ nb_.setCutoff(data.cutoff_);
+ nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
+ nb_.setXYMode(true);
+ gmx::AnalysisNeighborhoodSearch search =
+ nb_.initSearch(&data.pbc_, data.refPositions());
+ // Currently, grid searching not supported with XY.
+ //ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
+
+ testIsWithin(&search, data);
+ testMinimumDistance(&search, data);
+ testNearestPoint(&search, data);
+ testPairSearch(&search, data);
+}
+
TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
{
const NeighborhoodSearchTestData &data = TrivialTestData::get();
testPairSearch(&search2, data);
gmx::AnalysisNeighborhoodPair pair;
- pairSearch1.findNextPair(&pair);
+ ASSERT_TRUE(pairSearch1.findNextPair(&pair))
+ << "Test data did not contain any pairs for position 0 (problem in the test).";
EXPECT_EQ(0, pair.testIndex());
- EXPECT_TRUE(data.testPositions_[0].refPairs.count(pair.refIndex()) == 1);
+ {
+ NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+ EXPECT_TRUE(data.containsPair(0, searchPair));
+ }
- pairSearch2.findNextPair(&pair);
+ ASSERT_TRUE(pairSearch2.findNextPair(&pair))
+ << "Test data did not contain any pairs for position 1 (problem in the test).";
EXPECT_EQ(1, pair.testIndex());
- EXPECT_TRUE(data.testPositions_[1].refPairs.count(pair.refIndex()) == 1);
+ {
+ NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+ EXPECT_TRUE(data.containsPair(1, searchPair));
+ }
}
TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
++currentIndex;
}
EXPECT_EQ(currentIndex, pair.testIndex());
- EXPECT_TRUE(data.testPositions_[currentIndex].refPairs.count(pair.refIndex()) == 1);
+ NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+ EXPECT_TRUE(data.containsPair(currentIndex, searchPair));
pairSearch.skipRemainingPairsForTestPosition();
++currentIndex;
}
}
+TEST_F(NeighborhoodSearchTest, SimpleSearchExclusions)
+{
+ const NeighborhoodSearchTestData &data = RandomBoxFullPBCData::get();
+
+ ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
+ helper.generateExclusions();
+
+ nb_.setCutoff(data.cutoff_);
+ nb_.setTopologyExclusions(helper.exclusions());
+ nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
+ gmx::AnalysisNeighborhoodSearch search =
+ nb_.initSearch(&data.pbc_,
+ data.refPositions().exclusionIds(helper.refPosIds()));
+ ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
+
+ testPairSearchFull(&search, data,
+ data.testPositions().exclusionIds(helper.testPosIds()),
+ helper.exclusions());
+}
+
+TEST_F(NeighborhoodSearchTest, GridSearchExclusions)
+{
+ const NeighborhoodSearchTestData &data = RandomBoxFullPBCData::get();
+
+ ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
+ helper.generateExclusions();
+
+ nb_.setCutoff(data.cutoff_);
+ nb_.setTopologyExclusions(helper.exclusions());
+ nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
+ gmx::AnalysisNeighborhoodSearch search =
+ nb_.initSearch(&data.pbc_,
+ data.refPositions().exclusionIds(helper.refPosIds()));
+ ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
+
+ testPairSearchFull(&search, data,
+ data.testPositions().exclusionIds(helper.testPosIds()),
+ helper.exclusions());
+}
+
} // namespace