Basic exclusion support for analysis nbsearch
authorTeemu Murtola <teemu.murtola@gmail.com>
Sat, 9 Aug 2014 04:21:44 +0000 (07:21 +0300)
committerTeemu Murtola <teemu.murtola@gmail.com>
Sun, 10 Aug 2014 04:34:18 +0000 (07:34 +0300)
Add support for setting exclusions for AnalysisNeighborhood-based
searches.  Required for some functionality in gmx rdf, and the interface
is focused on satisfying that need.  Efficiency could potentially be
improved, as well as sanity checks on the input.  Now that the basic
machinery is there, it can be extended and the implementation improved
without hopefully affecting calling code significantly.

Change-Id: Ia667ee61ed1bb3d1171c22d7752f22517b3ccda8

src/gromacs/selection/nbsearch.cpp
src/gromacs/selection/nbsearch.h
src/gromacs/selection/selection.cpp
src/gromacs/selection/tests/nbsearch.cpp

index dd943ab8f093d3a455e8b6a2905a04d21304b835..b50c4dbd539c81fa63cc38077b496cf05976995f 100644 (file)
@@ -70,6 +70,7 @@
 #include "gromacs/math/vec.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/selection/position.h"
+#include "gromacs/topology/block.h"
 #include "gromacs/utility/arrayref.h"
 #include "gromacs/utility/exceptions.h"
 #include "gromacs/utility/gmxassert.h"
@@ -102,11 +103,13 @@ class AnalysisNeighborhoodSearchImpl
          *
          * \param[in]     mode      Search mode to use.
          * \param[in]     bXY       Whether to use 2D searching.
+         * \param[in]     excls     Exclusions.
          * \param[in]     pbc       PBC information.
          * \param[in]     positions Set of reference positions.
          */
         void init(AnalysisNeighborhood::SearchMode     mode,
                   bool                                 bXY,
+                  const t_blocka                      *excls,
                   const t_pbc                         *pbc,
                   const AnalysisNeighborhoodPositions &positions);
         PairSearchImplPointer getPairSearch();
@@ -168,16 +171,13 @@ class AnalysisNeighborhoodSearchImpl
         int                     nref_;
         //! Reference point positions.
         const rvec             *xref_;
-        //! Reference position ids (NULL if not available).
-        const int              *refid_;
+        //! Reference position exclusion IDs.
+        const int              *refExclusionIds_;
+        //! Exclusions.
+        const t_blocka         *excls_;
         //! PBC data.
         t_pbc                   pbc_;
 
-        //! Number of excluded reference positions for current test particle.
-        int                     nexcl_;
-        //! Exclusions for current test particle.
-        int                    *excl_;
-
         //! Whether grid searching is actually used for the current positions.
         bool                    bGrid_;
         //! Array allocated for storing in-unit-cell reference positions.
@@ -215,6 +215,9 @@ class AnalysisNeighborhoodPairSearchImpl
         explicit AnalysisNeighborhoodPairSearchImpl(const AnalysisNeighborhoodSearchImpl &search)
             : search_(search)
         {
+            testExclusionIds_ = NULL;
+            nexcl_            = 0;
+            excl_             = NULL;
             clear_rvec(xtest_);
             clear_ivec(testcell_);
             reset(-1);
@@ -240,6 +243,12 @@ class AnalysisNeighborhoodPairSearchImpl
         const AnalysisNeighborhoodSearchImpl   &search_;
         //! Reference to the test positions.
         ConstArrayRef<rvec>                     testPositions_;
+        //! Reference to the test exclusion indices.
+        const int                              *testExclusionIds_;
+        //! Number of excluded reference positions for current test particle.
+        int                                     nexcl_;
+        //! Exclusions for current test particle.
+        const int                              *excl_;
         //! Index of the currently active test position in \p testPositions_.
         int                                     testIndex_;
         //! Stores test position during a pair loop.
@@ -276,14 +285,11 @@ AnalysisNeighborhoodSearchImpl::AnalysisNeighborhoodSearchImpl(real cutoff)
     cutoff2_        = sqr(cutoff_);
     bXY_            = false;
 
-    nref_           = 0;
-    xref_           = NULL;
-    refid_          = NULL;
+    nref_            = 0;
+    xref_            = NULL;
+    refExclusionIds_ = NULL;
     std::memset(&pbc_, 0, sizeof(pbc_));
 
-    nexcl_          = 0;
-    excl_           = NULL;
-
     bGrid_          = false;
 
     xref_alloc_     = NULL;
@@ -475,6 +481,7 @@ void AnalysisNeighborhoodSearchImpl::addToGridCell(const ivec cell, int i)
 void AnalysisNeighborhoodSearchImpl::init(
         AnalysisNeighborhood::SearchMode     mode,
         bool                                 bXY,
+        const t_blocka                      *excls,
         const t_pbc                         *pbc,
         const AnalysisNeighborhoodPositions &positions)
 {
@@ -545,29 +552,17 @@ void AnalysisNeighborhoodSearchImpl::init(
     {
         xref_ = positions.x_;
     }
-    // TODO: Once exclusions are supported, this may need to be initialized.
-    refid_ = NULL;
-}
-
-#if 0
-/*! \brief
- * Sets the exclusions for the next neighborhood search.
- *
- * \param[in,out] d     Neighborhood search data structure.
- * \param[in]     nexcl Number of reference positions to exclude from next
- *      search.
- * \param[in]     excl  Indices of reference positions to exclude.
- *
- * The set exclusions remain in effect until the next call of this function.
- */
-void
-gmx_ana_nbsearch_set_excl(gmx_ana_nbsearch_t *d, int nexcl, int excl[])
-{
-
-    d->nexcl = nexcl;
-    d->excl  = excl;
+    excls_           = excls;
+    refExclusionIds_ = NULL;
+    if (excls != NULL)
+    {
+        // TODO: Check that the IDs are ascending, or remove the limitation.
+        refExclusionIds_ = positions.exclusionIds_;
+        GMX_RELEASE_ASSERT(refExclusionIds_ != NULL,
+                           "Exclusion IDs must be set for reference positions "
+                           "when exclusions are enabled");
+    }
 }
-#endif
 
 /********************************************************************
  * AnalysisNeighborhoodPairSearchImpl
@@ -586,6 +581,21 @@ void AnalysisNeighborhoodPairSearchImpl::reset(int testIndex)
                                             1, &xtest_);
             search_.mapPointToGridCell(xtest_, testcell_);
         }
+        if (search_.excls_ != NULL)
+        {
+            const int exclIndex  = testExclusionIds_[testIndex];
+            if (exclIndex < search_.excls_->nr)
+            {
+                const int startIndex = search_.excls_->index[exclIndex];
+                nexcl_ = search_.excls_->index[exclIndex + 1] - startIndex;
+                excl_  = &search_.excls_->a[startIndex];
+            }
+            else
+            {
+                nexcl_ = 0;
+                excl_  = NULL;
+            }
+        }
     }
     previ_     = -1;
     prevr2_    = 0.0;
@@ -605,34 +615,17 @@ void AnalysisNeighborhoodPairSearchImpl::nextTestPosition()
 
 bool AnalysisNeighborhoodPairSearchImpl::isExcluded(int j)
 {
-    if (exclind_ < search_.nexcl_)
+    if (exclind_ < nexcl_)
     {
-        if (search_.refid_)
+        const int refId = search_.refExclusionIds_[j];
+        while (exclind_ < nexcl_ && excl_[exclind_] < refId)
         {
-            while (exclind_ < search_.nexcl_
-                   && search_.excl_[exclind_] < search_.refid_[j])
-            {
-                ++exclind_;
-            }
-            if (exclind_ < search_.nexcl_
-                && search_.refid_[j] == search_.excl_[exclind_])
-            {
-                ++exclind_;
-                return true;
-            }
+            ++exclind_;
         }
-        else
+        if (exclind_ < nexcl_ && refId == excl_[exclind_])
         {
-            while (search_.bGrid_ && exclind_ < search_.nexcl_
-                   && search_.excl_[exclind_] < j)
-            {
-                ++exclind_;
-            }
-            if (search_.excl_[exclind_] == j)
-            {
-                ++exclind_;
-                return true;
-            }
+            ++exclind_;
+            return true;
         }
     }
     return false;
@@ -641,6 +634,9 @@ bool AnalysisNeighborhoodPairSearchImpl::isExcluded(int j)
 void AnalysisNeighborhoodPairSearchImpl::startSearch(
         const AnalysisNeighborhoodPositions &positions)
 {
+    testExclusionIds_ = positions.exclusionIds_;
+    GMX_RELEASE_ASSERT(search_.excls_ == NULL || testExclusionIds_ != NULL,
+                       "Exclusion IDs must be set when exclusions are enabled");
     if (positions.index_ < 0)
     {
         testPositions_ = ConstArrayRef<rvec>(positions.x_, positions.count_);
@@ -832,7 +828,7 @@ class AnalysisNeighborhood::Impl
         typedef AnalysisNeighborhoodSearch::ImplPointer SearchImplPointer;
         typedef std::vector<SearchImplPointer> SearchList;
 
-        Impl() : cutoff_(0), mode_(eSearchMode_Automatic), bXY_(false)
+        Impl() : cutoff_(0), excls_(NULL), mode_(eSearchMode_Automatic), bXY_(false)
         {
         }
         ~Impl()
@@ -850,6 +846,7 @@ class AnalysisNeighborhood::Impl
         tMPI::mutex             createSearchMutex_;
         SearchList              searchList_;
         real                    cutoff_;
+        const t_blocka         *excls_;
         SearchMode              mode_;
         bool                    bXY_;
 };
@@ -898,6 +895,13 @@ void AnalysisNeighborhood::setXYMode(bool bXY)
     impl_->bXY_ = bXY;
 }
 
+void AnalysisNeighborhood::setTopologyExclusions(const t_blocka *excls)
+{
+    GMX_RELEASE_ASSERT(impl_->searchList_.empty(),
+                       "Changing the exclusions after initSearch() not currently supported");
+    impl_->excls_ = excls;
+}
+
 void AnalysisNeighborhood::setMode(SearchMode mode)
 {
     impl_->mode_ = mode;
@@ -913,7 +917,7 @@ AnalysisNeighborhood::initSearch(const t_pbc                         *pbc,
                                  const AnalysisNeighborhoodPositions &positions)
 {
     Impl::SearchImplPointer search(impl_->getSearch());
-    search->init(mode(), impl_->bXY_, pbc, positions);
+    search->init(mode(), impl_->bXY_, impl_->excls_, pbc, positions);
     return AnalysisNeighborhoodSearch(search);
 }
 
index 79d06b9b02f8c575806151966890cfa1165bae5e..be25e7a3fdfd90f0f3e4e08df14fcc6d0eeb4387 100644 (file)
 #include <boost/shared_ptr.hpp>
 
 #include "../math/vectypes.h"
+#include "../utility/arrayref.h"
 #include "../utility/common.h"
 #include "../utility/gmxassert.h"
 #include "../utility/real.h"
 
+struct t_blocka;
 struct t_pbc;
 
 namespace gmx
@@ -106,17 +108,33 @@ class AnalysisNeighborhoodPositions
          * to methods that accept positions.
          */
         AnalysisNeighborhoodPositions(const rvec &x)
-            : count_(1), index_(-1), x_(&x)
+            : count_(1), index_(-1), x_(&x), exclusionIds_(NULL)
         {
         }
         /*! \brief
          * Initializes positions from an array of position vectors.
          */
         AnalysisNeighborhoodPositions(const rvec x[], int count)
-            : count_(count), index_(-1), x_(x)
+            : count_(count), index_(-1), x_(x), exclusionIds_(NULL)
         {
         }
 
+        /*! \brief
+         * Sets indices to use for mapping exclusions to these positions.
+         *
+         * The exclusion IDs can always be set, but they are ignored unless
+         * actual exclusions have been set with
+         * AnalysisNeighborhood::setTopologyExclusions().
+         */
+        AnalysisNeighborhoodPositions &
+        exclusionIds(ConstArrayRef<int> ids)
+        {
+            GMX_ASSERT(static_cast<int>(ids.size()) == count_,
+                       "Exclusion id array should match the number of positions");
+            exclusionIds_ = ids.data();
+            return *this;
+        }
+
         /*! \brief
          * Selects a single position to use from an array.
          *
@@ -137,6 +155,7 @@ class AnalysisNeighborhoodPositions
         int                     count_;
         int                     index_;
         const rvec             *x_;
+        const int              *exclusionIds_;
 
         //! To access the positions for initialization.
         friend class internal::AnalysisNeighborhoodSearchImpl;
@@ -167,12 +186,8 @@ class AnalysisNeighborhoodPositions
  * a single thread.
  *
  * \todo
- * Support for exclusions.
- * The 4.5/4.6 C API had very low-level support for exclusions, which was not
- * very convenient to use, and hadn't been tested much.  The internal code that
- * it used to do the exclusion during the search itself is still there, but it
- * needs more thought on what would be a convenient way to initialize it.
- * Can be implemented once there is need for it in some calling code.
+ * Generalize the exclusion machinery to make it easier to use for other cases
+ * than atom-atom exclusions from the topology.
  *
  * \inpublicapi
  * \ingroup module_selection
@@ -218,6 +233,23 @@ class AnalysisNeighborhood
          * Does not throw.
          */
         void setXYMode(bool bXY);
+        /*! \brief
+         * Sets atom exclusions from a topology.
+         *
+         * The \p excls structure specifies the exclusions from test positions
+         * to reference positions, i.e., a block starting at `excls->index[i]`
+         * specifies the exclusions for test position `i`, and the indices in
+         * `excls->a` are indices of the reference positions.  If `excls->nr`
+         * is smaller than a test position id, then such test positions do not
+         * have any exclusions.
+         * It is assumed that the indices within a block of indices in
+         * `excls->a` is ascending.
+         *
+         * Does not throw.
+         *
+         * \see AnalysisNeighborhoodPositions::exclusionIds()
+         */
+        void setTopologyExclusions(const t_blocka *excls);
         /*! \brief
          * Sets the algorithm to use for searching.
          *
index 795b08065e6f3f153558aa8f21ce19c19f3742a0..892496298aebe0d090dd194163a0d8b602e9c4a7 100644 (file)
@@ -247,8 +247,13 @@ SelectionData::restoreOriginalPositions(const t_topology *top)
 
 Selection::operator AnalysisNeighborhoodPositions() const
 {
-    return AnalysisNeighborhoodPositions(data().rawPositions_.x,
-                                         data().rawPositions_.count());
+    AnalysisNeighborhoodPositions pos(data().rawPositions_.x,
+                                      data().rawPositions_.count());
+    if (hasOnlyAtoms())
+    {
+        pos.exclusionIds(atomIndices());
+    }
+    return pos;
 }
 
 
@@ -347,9 +352,15 @@ Selection::printDebugInfo(FILE *fp, int nmaxind) const
 
 SelectionPosition::operator AnalysisNeighborhoodPositions() const
 {
-    return AnalysisNeighborhoodPositions(sel_->rawPositions_.x,
-                                         sel_->rawPositions_.count())
-               .selectSingleFromArray(i_);
+    AnalysisNeighborhoodPositions pos(sel_->rawPositions_.x,
+                                      sel_->rawPositions_.count());
+    if (sel_->hasOnlyAtoms())
+    {
+        // TODO: Move atomIndices() such that it can be reused here as well.
+        pos.exclusionIds(ConstArrayRef<int>(sel_->rawPositions_.m.mapb.a,
+                                            sel_->rawPositions_.m.mapb.nra));
+    }
+    return pos.selectSingleFromArray(i_);
 }
 
 } // namespace gmx
index f6c162ba2c7b5e8918a6160ff860910fc6d24129..370235bd3160a5a15b3eddc9ec952a7750f1b10e 100644 (file)
 
 #include <algorithm>
 #include <limits>
+#include <numeric>
 #include <vector>
 
 #include "gromacs/math/vec.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/random/random.h"
+#include "gromacs/topology/block.h"
 #include "gromacs/utility/smalloc.h"
 
 #include "testutils/testasserts.h"
@@ -74,7 +76,8 @@ class NeighborhoodSearchTestData
         struct RefPair
         {
             RefPair(int refIndex, real distance)
-                : refIndex(refIndex), distance(distance), bFound(false)
+                : refIndex(refIndex), distance(distance), bFound(false),
+                  bExcluded(false)
             {
             }
 
@@ -85,7 +88,12 @@ class NeighborhoodSearchTestData
 
             int                 refIndex;
             real                distance;
+            // The variables below are state variables that are only used
+            // during the actual testing after creating a copy of the reference
+            // pair list, not as part of the reference data.
+            // Simpler to have just a single structure for both purposes.
             bool                bFound;
+            bool                bExcluded;
         };
 
         struct TestPosition
@@ -278,6 +286,105 @@ void NeighborhoodSearchTestData::computeReferencesInternal(t_pbc *pbc, bool bXY)
     }
 }
 
+/********************************************************************
+ * ExclusionsHelper
+ */
+
+class ExclusionsHelper
+{
+    public:
+        static void markExcludedPairs(RefPairList *refPairs, int testIndex,
+                                      const t_blocka *excls);
+
+        ExclusionsHelper(int refPosCount, int testPosCount);
+
+        void generateExclusions();
+
+        const t_blocka *exclusions() const { return &excls_; }
+
+        gmx::ConstArrayRef<int> refPosIds() const
+        {
+            return gmx::ConstArrayRef<int>(exclusionIds_.begin(),
+                                           exclusionIds_.begin() + refPosCount_);
+        }
+        gmx::ConstArrayRef<int> testPosIds() const
+        {
+            return gmx::ConstArrayRef<int>(exclusionIds_.begin(),
+                                           exclusionIds_.begin() + testPosCount_);
+        }
+
+    private:
+        int              refPosCount_;
+        int              testPosCount_;
+        std::vector<int> exclusionIds_;
+        std::vector<int> exclsIndex_;
+        std::vector<int> exclsAtoms_;
+        t_blocka         excls_;
+};
+
+// static
+void ExclusionsHelper::markExcludedPairs(RefPairList *refPairs, int testIndex,
+                                         const t_blocka *excls)
+{
+    int count = 0;
+    for (int i = excls->index[testIndex]; i < excls->index[testIndex + 1]; ++i)
+    {
+        const int                           excludedIndex = excls->a[i];
+        NeighborhoodSearchTestData::RefPair searchPair(excludedIndex, 0.0);
+        RefPairList::iterator               excludedRefPair
+            = std::lower_bound(refPairs->begin(), refPairs->end(), searchPair);
+        if (excludedRefPair != refPairs->end()
+            && excludedRefPair->refIndex == excludedIndex)
+        {
+            excludedRefPair->bFound    = true;
+            excludedRefPair->bExcluded = true;
+            ++count;
+        }
+    }
+}
+
+ExclusionsHelper::ExclusionsHelper(int refPosCount, int testPosCount)
+    : refPosCount_(refPosCount), testPosCount_(testPosCount)
+{
+    // Generate an array of 0, 1, 2, ...
+    // TODO: Make the tests work also with non-trivial exclusion IDs,
+    // and test that.
+    exclusionIds_.resize(std::max(refPosCount, testPosCount), 1);
+    exclusionIds_[0] = 0;
+    std::partial_sum(exclusionIds_.begin(), exclusionIds_.end(),
+                     exclusionIds_.begin());
+
+    excls_.nr           = 0;
+    excls_.index        = NULL;
+    excls_.nra          = 0;
+    excls_.a            = NULL;
+    excls_.nalloc_index = 0;
+    excls_.nalloc_a     = 0;
+}
+
+void ExclusionsHelper::generateExclusions()
+{
+    // TODO: Consider a better set of test data, where the density of the
+    // particles would be higher, or where the exclusions would not be random,
+    // to make a higher percentage of the exclusions to actually be within the
+    // cutoff.
+    exclsIndex_.reserve(testPosCount_ + 1);
+    exclsAtoms_.reserve(testPosCount_ * 20);
+    exclsIndex_.push_back(0);
+    for (int i = 0; i < testPosCount_; ++i)
+    {
+        for (int j = 0; j < 20; ++j)
+        {
+            exclsAtoms_.push_back(i + j*3);
+        }
+        exclsIndex_.push_back(exclsAtoms_.size());
+    }
+    excls_.nr    = exclsIndex_.size();
+    excls_.index = &exclsIndex_[0];
+    excls_.nra   = exclsAtoms_.size();
+    excls_.a     = &exclsAtoms_[0];
+}
+
 /********************************************************************
  * NeighborhoodSearchTest
  */
@@ -293,6 +400,10 @@ class NeighborhoodSearchTest : public ::testing::Test
                               const NeighborhoodSearchTestData &data);
         void testPairSearch(gmx::AnalysisNeighborhoodSearch  *search,
                             const NeighborhoodSearchTestData &data);
+        void testPairSearchFull(gmx::AnalysisNeighborhoodSearch          *search,
+                                const NeighborhoodSearchTestData         &data,
+                                const gmx::AnalysisNeighborhoodPositions &pos,
+                                const t_blocka                           *excls);
 
         gmx::AnalysisNeighborhood        nb_;
 };
@@ -345,57 +456,118 @@ void NeighborhoodSearchTest::testNearestPoint(
     }
 }
 
+/*! \brief
+ * Helper function to check that all expected pairs were found.
+ */
+static void checkAllPairsFound(const RefPairList &refPairs)
+{
+    // This could be elegantly expressed with Google Mock matchers, but that
+    // has a significant effect on the runtime of the tests...
+    for (RefPairList::const_iterator i = refPairs.begin(); i != refPairs.end(); ++i)
+    {
+        if (!i->bFound)
+        {
+            ADD_FAILURE()
+            << "Some pairs within the cutoff were not found.";
+            break;
+        }
+    }
+}
+
 void NeighborhoodSearchTest::testPairSearch(
         gmx::AnalysisNeighborhoodSearch  *search,
         const NeighborhoodSearchTestData &data)
 {
-    NeighborhoodSearchTestData::TestPositionList::const_iterator i;
-    // TODO: Test also searching all the test positions in a single search;
-    // currently the implementation just contains this loop, though, but in
-    // the future that may trigger a different code path.
-    for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
+    testPairSearchFull(search, data, data.testPositions(), NULL);
+}
+
+void NeighborhoodSearchTest::testPairSearchFull(
+        gmx::AnalysisNeighborhoodSearch          *search,
+        const NeighborhoodSearchTestData         &data,
+        const gmx::AnalysisNeighborhoodPositions &pos,
+        const t_blocka                           *excls)
+{
+    // TODO: Some parts of this code do not work properly if pos does not
+    // contain all the test positions.
+    std::set<int> remainingTestPositions;
+    for (size_t i = 0; i < data.testPositions_.size(); ++i)
+    {
+        remainingTestPositions.insert(i);
+    }
+    gmx::AnalysisNeighborhoodPairSearch pairSearch
+        = search->startPairSearch(pos);
+    gmx::AnalysisNeighborhoodPair       pair;
+    // TODO: There is an ordering assumption here that may break in the future:
+    // all pairs for a test position are assumed to be returned consencutively.
+    RefPairList refPairs;
+    int         prevTestPos = -1;
+    while (pairSearch.findNextPair(&pair))
     {
-        RefPairList                         refPairs = i->refPairs;
-        gmx::AnalysisNeighborhoodPairSearch pairSearch
-            = search->startPairSearch(i->x);
-        gmx::AnalysisNeighborhoodPair       pair;
-        while (pairSearch.findNextPair(&pair))
+        if (pair.testIndex() != prevTestPos)
         {
-            EXPECT_EQ(0, pair.testIndex());
-            NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(),
-                                                           sqrt(pair.distance2()));
-            RefPairList::iterator               foundRefPair
-                = std::lower_bound(refPairs.begin(), refPairs.end(), searchPair);
-            if (foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex())
+            if (prevTestPos != -1)
             {
-                ADD_FAILURE()
-                << "Expected: Position " << pair.refIndex()
-                << " is within cutoff.\n"
-                << "  Actual: It is not.";
+                checkAllPairsFound(refPairs);
             }
-            else if (foundRefPair->bFound)
+            const int testIndex = pair.testIndex();
+            if (remainingTestPositions.count(testIndex) == 0)
             {
                 ADD_FAILURE()
-                << "Expected: Position " << pair.refIndex()
-                << " is returned only once.\n"
-                << "  Actual: It is returned multiple times.";
+                << "Pairs for test position " << testIndex
+                << " are returned more than once.";
             }
-            else
+            remainingTestPositions.erase(testIndex);
+            refPairs = data.testPositions_[testIndex].refPairs;
+            if (excls != NULL)
             {
-                foundRefPair->bFound = true;
-                EXPECT_REAL_EQ_TOL(foundRefPair->distance, searchPair.distance,
-                                   gmx::test::ulpTolerance(64));
+                ExclusionsHelper::markExcludedPairs(&refPairs, testIndex, excls);
             }
+            prevTestPos = testIndex;
         }
-        for (RefPairList::const_iterator j = refPairs.begin(); j != refPairs.end(); ++j)
+
+        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(),
+                                                       sqrt(pair.distance2()));
+        RefPairList::iterator               foundRefPair
+            = std::lower_bound(refPairs.begin(), refPairs.end(), searchPair);
+        if (foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex())
         {
-            if (!j->bFound)
-            {
-                ADD_FAILURE()
-                << "Expected: All pairs within cutoff will be returned.\n"
-                << "  Actual: Position " << j->refIndex << " is not found.";
-                break;
-            }
+            ADD_FAILURE()
+            << "Expected: Pair (ref: " << pair.refIndex() << ", test: "
+            << pair.testIndex() << ") is not within the cutoff.\n"
+            << "  Actual: It is returned.";
+        }
+        else if (foundRefPair->bExcluded)
+        {
+            ADD_FAILURE()
+            << "Expected: Pair (ref: " << pair.refIndex() << ", test: "
+            << pair.testIndex() << ") is excluded from the search.\n"
+            << "  Actual: It is returned.";
+        }
+        else if (foundRefPair->bFound)
+        {
+            ADD_FAILURE()
+            << "Expected: Pair (ref: " << pair.refIndex() << ", test: "
+            << pair.testIndex() << ") is returned only once.\n"
+            << "  Actual: It is returned multiple times.";
+        }
+        else
+        {
+            foundRefPair->bFound = true;
+            EXPECT_REAL_EQ_TOL(foundRefPair->distance, searchPair.distance,
+                               gmx::test::ulpTolerance(64))
+            << "Distance computed by the neighborhood search does not match.";
+        }
+    }
+    checkAllPairsFound(refPairs);
+    for (std::set<int>::const_iterator i = remainingTestPositions.begin();
+         i != remainingTestPositions.end(); ++i)
+    {
+        if (!data.testPositions_[*i].refPairs.empty())
+        {
+            ADD_FAILURE()
+            << "Expected: Pairs would be returned for test position " << *i << ".\n"
+            << "  Actual: None were returned.";
+            break;
         }
     }
 }
@@ -637,14 +809,16 @@ TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
     testPairSearch(&search2, data);
 
     gmx::AnalysisNeighborhoodPair pair;
-    pairSearch1.findNextPair(&pair);
+    ASSERT_TRUE(pairSearch1.findNextPair(&pair))
+    << "Test data did not contain any pairs for position 0 (problem in the test).";
     EXPECT_EQ(0, pair.testIndex());
     {
         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
         EXPECT_TRUE(data.containsPair(0, searchPair));
     }
 
-    pairSearch2.findNextPair(&pair);
+    ASSERT_TRUE(pairSearch2.findNextPair(&pair))
+    << "Test data did not contain any pairs for position 1 (problem in the test).";
     EXPECT_EQ(1, pair.testIndex());
     {
         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
@@ -681,4 +855,44 @@ TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
     }
 }
 
+TEST_F(NeighborhoodSearchTest, SimpleSearchExclusions)
+{
+    const NeighborhoodSearchTestData &data = RandomBoxFullPBCData::get();
+
+    ExclusionsHelper                  helper(data.refPosCount_, data.testPositions_.size());
+    helper.generateExclusions();
+
+    nb_.setCutoff(data.cutoff_);
+    nb_.setTopologyExclusions(helper.exclusions());
+    nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
+    gmx::AnalysisNeighborhoodSearch search =
+        nb_.initSearch(&data.pbc_,
+                       data.refPositions().exclusionIds(helper.refPosIds()));
+    ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
+
+    testPairSearchFull(&search, data,
+                       data.testPositions().exclusionIds(helper.testPosIds()),
+                       helper.exclusions());
+}
+
+TEST_F(NeighborhoodSearchTest, GridSearchExclusions)
+{
+    const NeighborhoodSearchTestData &data = RandomBoxFullPBCData::get();
+
+    ExclusionsHelper                  helper(data.refPosCount_, data.testPositions_.size());
+    helper.generateExclusions();
+
+    nb_.setCutoff(data.cutoff_);
+    nb_.setTopologyExclusions(helper.exclusions());
+    nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
+    gmx::AnalysisNeighborhoodSearch search =
+        nb_.initSearch(&data.pbc_,
+                       data.refPositions().exclusionIds(helper.refPosIds()));
+    ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
+
+    testPairSearchFull(&search, data,
+                       data.testPositions().exclusionIds(helper.testPosIds()),
+                       helper.exclusions());
+}
+
 } // namespace