Basic support for 2D neighborhood search for analysis
authorTeemu Murtola <teemu.murtola@gmail.com>
Sun, 3 Aug 2014 04:27:36 +0000 (07:27 +0300)
committerTeemu Murtola <teemu.murtola@gmail.com>
Thu, 7 Aug 2014 14:50:25 +0000 (17:50 +0300)
Add support for doing a neighborhood search based on distances in the
X-Y plane into the analysis neighborhood search routines.  The basic
logic is lifted from 'gmx rdf -xy' and made a bit more general.

Currently, only basic searching is implemented, but grid-based should
not be too difficult either.  Later implementation of a grid-based
search should be transparent to callers.

Change-Id: I635d7f7a5eb0136d7a3a4c968ddaa34b03ae3bc7

src/gromacs/selection/nbsearch.cpp
src/gromacs/selection/nbsearch.h
src/gromacs/selection/tests/nbsearch.cpp

index 32b2d330f7a280bc1b79c42490b4d72af0f26080..dd943ab8f093d3a455e8b6a2905a04d21304b835 100644 (file)
  */
 #include "gromacs/selection/nbsearch.h"
 
-#include <math.h>
+#include <cmath>
+#include <cstring>
 
 #include <algorithm>
 #include <vector>
 
 #include "thread_mpi/mutex.h"
 
+#include "gromacs/legacyheaders/names.h"
+
 #include "gromacs/math/vec.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/selection/position.h"
 #include "gromacs/utility/arrayref.h"
+#include "gromacs/utility/exceptions.h"
 #include "gromacs/utility/gmxassert.h"
 #include "gromacs/utility/smalloc.h"
+#include "gromacs/utility/stringutil.h"
 
 namespace gmx
 {
@@ -96,10 +101,12 @@ class AnalysisNeighborhoodSearchImpl
          * Initializes the search with a given box and reference positions.
          *
          * \param[in]     mode      Search mode to use.
+         * \param[in]     bXY       Whether to use 2D searching.
          * \param[in]     pbc       PBC information.
          * \param[in]     positions Set of reference positions.
          */
         void init(AnalysisNeighborhood::SearchMode     mode,
+                  bool                                 bXY,
                   const t_pbc                         *pbc,
                   const AnalysisNeighborhoodPositions &positions);
         PairSearchImplPointer getPairSearch();
@@ -116,14 +123,14 @@ class AnalysisNeighborhoodSearchImpl
          * \param[in]     pbc  Information about the box.
          * \returns  false if grid search is not suitable.
          */
-        bool initGridCells(const t_pbc *pbc);
+        bool initGridCells(const t_pbc &pbc);
         /*! \brief
          * Sets ua a search grid for a given box.
          *
          * \param[in]     pbc  Information about the box.
          * \returns  false if grid search is not suitable.
          */
-        bool initGrid(const t_pbc *pbc);
+        bool initGrid(const t_pbc &pbc);
         /*! \brief
          * Maps a point into a grid cell.
          *
@@ -154,6 +161,8 @@ class AnalysisNeighborhoodSearchImpl
         real                    cutoff_;
         //! The cutoff squared.
         real                    cutoff2_;
+        //! Whether to do searching in XY plane only.
+        bool                    bXY_;
 
         //! Number of reference points for the current frame.
         int                     nref_;
@@ -162,7 +171,7 @@ class AnalysisNeighborhoodSearchImpl
         //! Reference position ids (NULL if not available).
         const int              *refid_;
         //! PBC data.
-        t_pbc                  *pbc_;
+        t_pbc                   pbc_;
 
         //! Number of excluded reference positions for current test particle.
         int                     nexcl_;
@@ -237,6 +246,8 @@ class AnalysisNeighborhoodPairSearchImpl
         rvec                                    xtest_;
         //! Stores the previous returned position during a pair loop.
         int                                     previ_;
+        //! Stores the pair distance corresponding to previ_;
+        real                                    prevr2_;
         //! Stores the current exclusion index during loops.
         int                                     exclind_;
         //! Stores the test particle cell index during loops.
@@ -263,11 +274,12 @@ AnalysisNeighborhoodSearchImpl::AnalysisNeighborhoodSearchImpl(real cutoff)
         bTryGrid_   = false;
     }
     cutoff2_        = sqr(cutoff_);
+    bXY_            = false;
 
     nref_           = 0;
     xref_           = NULL;
     refid_          = NULL;
-    pbc_            = NULL;
+    std::memset(&pbc_, 0, sizeof(pbc_));
 
     nexcl_          = 0;
     excl_           = NULL;
@@ -356,16 +368,16 @@ void AnalysisNeighborhoodSearchImpl::initGridCellNeighborList()
     }
 }
 
-bool AnalysisNeighborhoodSearchImpl::initGridCells(const t_pbc *pbc)
+bool AnalysisNeighborhoodSearchImpl::initGridCells(const t_pbc &pbc)
 {
     const real targetsize =
-        pow(pbc->box[XX][XX] * pbc->box[YY][YY] * pbc->box[ZZ][ZZ]
+        pow(pbc.box[XX][XX] * pbc.box[YY][YY] * pbc.box[ZZ][ZZ]
             * 10 / nref_, static_cast<real>(1./3.));
 
     int cellCount = 1;
     for (int dd = 0; dd < DIM; ++dd)
     {
-        ncelldim_[dd] = static_cast<int>(pbc->box[dd][dd] / targetsize);
+        ncelldim_[dd] = static_cast<int>(pbc.box[dd][dd] / targetsize);
         cellCount    *= ncelldim_[dd];
         if (ncelldim_[dd] < 3)
         {
@@ -386,10 +398,10 @@ bool AnalysisNeighborhoodSearchImpl::initGridCells(const t_pbc *pbc)
     return true;
 }
 
-bool AnalysisNeighborhoodSearchImpl::initGrid(const t_pbc *pbc)
+bool AnalysisNeighborhoodSearchImpl::initGrid(const t_pbc &pbc)
 {
     /* TODO: This check could be improved. */
-    if (0.5*pbc->max_cutoff2 < cutoff2_)
+    if (0.5*pbc.max_cutoff2 < cutoff2_)
     {
         return false;
     }
@@ -399,12 +411,12 @@ bool AnalysisNeighborhoodSearchImpl::initGrid(const t_pbc *pbc)
         return false;
     }
 
-    bTric_ = TRICLINIC(pbc->box);
+    bTric_ = TRICLINIC(pbc.box);
     if (bTric_)
     {
         for (int dd = 0; dd < DIM; ++dd)
         {
-            svmul(1.0 / ncelldim_[dd], pbc->box[dd], cellbox_[dd]);
+            svmul(1.0 / ncelldim_[dd], pbc.box[dd], cellbox_[dd]);
         }
         m_inv_ur0(cellbox_, recipcell_);
     }
@@ -412,7 +424,7 @@ bool AnalysisNeighborhoodSearchImpl::initGrid(const t_pbc *pbc)
     {
         for (int dd = 0; dd < DIM; ++dd)
         {
-            cellbox_[dd][dd]   = pbc->box[dd][dd] / ncelldim_[dd];
+            cellbox_[dd][dd]   = pbc.box[dd][dd] / ncelldim_[dd];
             recipcell_[dd][dd] = 1.0 / cellbox_[dd][dd];
         }
     }
@@ -462,17 +474,42 @@ void AnalysisNeighborhoodSearchImpl::addToGridCell(const ivec cell, int i)
 
 void AnalysisNeighborhoodSearchImpl::init(
         AnalysisNeighborhood::SearchMode     mode,
+        bool                                 bXY,
         const t_pbc                         *pbc,
         const AnalysisNeighborhoodPositions &positions)
 {
     GMX_RELEASE_ASSERT(positions.index_ == -1,
                        "Individual indexed positions not supported as reference");
-    pbc_  = const_cast<t_pbc *>(pbc);
+    bXY_ = bXY;
+    if (bXY_ && pbc->ePBC != epbcNONE)
+    {
+        if (pbc->ePBC != epbcXY && pbc->ePBC != epbcXYZ)
+        {
+            std::string message =
+                formatString("Computations in the XY plane are not supported with PBC type '%s'",
+                             EPBC(pbc->ePBC));
+            GMX_THROW(NotImplementedError(message));
+        }
+        if (std::fabs(pbc->box[ZZ][XX]) > GMX_REAL_EPS*pbc->box[ZZ][ZZ] ||
+            std::fabs(pbc->box[ZZ][YY]) > GMX_REAL_EPS*pbc->box[ZZ][ZZ])
+        {
+            GMX_THROW(NotImplementedError("Computations in the XY plane are not supported when the last box vector is not parallel to the Z axis"));
+        }
+        set_pbc(&pbc_, epbcXY, const_cast<rvec *>(pbc->box));
+    }
+    else if (pbc != NULL)
+    {
+        pbc_  = *pbc;
+    }
+    else
+    {
+        pbc_.ePBC = epbcNONE;
+    }
     nref_ = positions.count_;
     // TODO: Consider whether it would be possible to support grid searching in
     // more cases.
     if (mode == AnalysisNeighborhood::eSearchMode_Simple
-        || pbc_ == NULL || pbc_->ePBC != epbcXYZ)
+        || pbc_.ePBC != epbcXYZ)
     {
         bGrid_ = false;
     }
@@ -494,7 +531,7 @@ void AnalysisNeighborhoodSearchImpl::init(
         {
             copy_rvec(positions.x_[i], xref_alloc_[i]);
         }
-        put_atoms_in_triclinic_unitcell(ecenterTRIC, pbc_->box,
+        put_atoms_in_triclinic_unitcell(ecenterTRIC, pbc_.box,
                                         nref_, xref_alloc_);
         for (int i = 0; i < nref_; ++i)
         {
@@ -544,12 +581,14 @@ void AnalysisNeighborhoodPairSearchImpl::reset(int testIndex)
         copy_rvec(testPositions_[testIndex_], xtest_);
         if (search_.bGrid_)
         {
-            put_atoms_in_triclinic_unitcell(ecenterTRIC, search_.pbc_->box,
+            put_atoms_in_triclinic_unitcell(ecenterTRIC,
+                                            const_cast<rvec *>(search_.pbc_.box),
                                             1, &xtest_);
             search_.mapPointToGridCell(xtest_, testcell_);
         }
     }
     previ_     = -1;
+    prevr2_    = 0.0;
     exclind_   = 0;
     prevnbi_   = 0;
     prevcai_   = -1;
@@ -623,6 +662,8 @@ bool AnalysisNeighborhoodPairSearchImpl::searchNext(Action action)
     {
         if (search_.bGrid_)
         {
+            GMX_RELEASE_ASSERT(!search_.bXY_, "Grid-based XY searches not implemented");
+
             int nbi = prevnbi_;
             int cai = prevcai_ + 1;
 
@@ -646,7 +687,7 @@ bool AnalysisNeighborhoodPairSearchImpl::searchNext(Action action)
                         continue;
                     }
                     rvec       dx;
-                    pbc_dx_aiuc(search_.pbc_, xtest_, search_.xref_[i], dx);
+                    pbc_dx_aiuc(&search_.pbc_, xtest_, search_.xref_[i], dx);
                     const real r2 = norm2(dx);
                     if (r2 <= search_.cutoff2_)
                     {
@@ -655,6 +696,7 @@ bool AnalysisNeighborhoodPairSearchImpl::searchNext(Action action)
                             prevnbi_ = nbi;
                             prevcai_ = cai;
                             previ_   = i;
+                            prevr2_  = r2;
                             return true;
                         }
                     }
@@ -672,20 +714,24 @@ bool AnalysisNeighborhoodPairSearchImpl::searchNext(Action action)
                     continue;
                 }
                 rvec dx;
-                if (search_.pbc_)
+                if (search_.pbc_.ePBC != epbcNONE)
                 {
-                    pbc_dx(search_.pbc_, xtest_, search_.xref_[i], dx);
+                    pbc_dx(&search_.pbc_, xtest_, search_.xref_[i], dx);
                 }
                 else
                 {
                     rvec_sub(xtest_, search_.xref_[i], dx);
                 }
-                const real r2 = norm2(dx);
+                const real r2
+                    = search_.bXY_
+                        ? dx[XX]*dx[XX] + dx[YY]*dx[YY]
+                        : norm2(dx);
                 if (r2 <= search_.cutoff2_)
                 {
                     if (action(i, r2))
                     {
-                        previ_ = i;
+                        previ_  = i;
+                        prevr2_ = r2;
                         return true;
                     }
                 }
@@ -705,7 +751,7 @@ void AnalysisNeighborhoodPairSearchImpl::initFoundPair(
     }
     else
     {
-        *pair = AnalysisNeighborhoodPair(previ_, testIndex_);
+        *pair = AnalysisNeighborhoodPair(previ_, testIndex_, prevr2_);
     }
 }
 
@@ -786,7 +832,7 @@ class AnalysisNeighborhood::Impl
         typedef AnalysisNeighborhoodSearch::ImplPointer SearchImplPointer;
         typedef std::vector<SearchImplPointer> SearchList;
 
-        Impl() : cutoff_(0), mode_(eSearchMode_Automatic)
+        Impl() : cutoff_(0), mode_(eSearchMode_Automatic), bXY_(false)
         {
         }
         ~Impl()
@@ -805,6 +851,7 @@ class AnalysisNeighborhood::Impl
         SearchList              searchList_;
         real                    cutoff_;
         SearchMode              mode_;
+        bool                    bXY_;
 };
 
 AnalysisNeighborhood::Impl::SearchImplPointer
@@ -846,6 +893,11 @@ void AnalysisNeighborhood::setCutoff(real cutoff)
     impl_->cutoff_ = cutoff;
 }
 
+void AnalysisNeighborhood::setXYMode(bool bXY)
+{
+    impl_->bXY_ = bXY;
+}
+
 void AnalysisNeighborhood::setMode(SearchMode mode)
 {
     impl_->mode_ = mode;
@@ -861,7 +913,7 @@ AnalysisNeighborhood::initSearch(const t_pbc                         *pbc,
                                  const AnalysisNeighborhoodPositions &positions)
 {
     Impl::SearchImplPointer search(impl_->getSearch());
-    search->init(mode(), pbc, positions);
+    search->init(mode(), impl_->bXY_, pbc, positions);
     return AnalysisNeighborhoodSearch(search);
 }
 
@@ -924,7 +976,7 @@ AnalysisNeighborhoodSearch::nearestPoint(
     int           closestPoint = -1;
     MindistAction action(&closestPoint, &minDist2);
     (void)pairSearch.searchNext(action);
-    return AnalysisNeighborhoodPair(closestPoint, 0);
+    return AnalysisNeighborhoodPair(closestPoint, 0, minDist2);
 }
 
 AnalysisNeighborhoodPairSearch
index 4732ae0b31b24b86d633e724d51ac0044f082adf..79d06b9b02f8c575806151966890cfa1165bae5e 100644 (file)
@@ -196,7 +196,7 @@ class AnalysisNeighborhood
         ~AnalysisNeighborhood();
 
         /*! \brief
-         * Set cutoff distance for the neighborhood searching.
+         * Sets cutoff distance for the neighborhood searching.
          *
          * \param[in]  cutoff Cutoff distance for the search
          *   (<=0 stands for no cutoff).
@@ -207,6 +207,17 @@ class AnalysisNeighborhood
          * Does not throw.
          */
         void setCutoff(real cutoff);
+        /*! \brief
+         * Sets the search to only happen in the XY plane.
+         *
+         * Z component of the coordinates is not used in the searching,
+         * and returned distances are computed in the XY plane.
+         * Only boxes with the third box vector parallel to the Z axis are
+         * currently implemented.
+         *
+         * Does not throw.
+         */
+        void setXYMode(bool bXY);
         /*! \brief
          * Sets the algorithm to use for searching.
          *
@@ -257,10 +268,10 @@ class AnalysisNeighborhoodPair
 {
     public:
         //! Initializes an invalid pair.
-        AnalysisNeighborhoodPair() : refIndex_(-1), testIndex_(0) {}
+        AnalysisNeighborhoodPair() : refIndex_(-1), testIndex_(0), distance2_(0.0) {}
         //! Initializes a pair object with the given data.
-        AnalysisNeighborhoodPair(int refIndex, int testIndex)
-            : refIndex_(refIndex), testIndex_(testIndex)
+        AnalysisNeighborhoodPair(int refIndex, int testIndex, real distance2)
+            : refIndex_(refIndex), testIndex_(testIndex), distance2_(distance2)
         {
         }
 
@@ -294,10 +305,19 @@ class AnalysisNeighborhoodPair
             GMX_ASSERT(isValid(), "Accessing invalid object");
             return testIndex_;
         }
+        /*! \brief
+         * Returns the squared distance between the pair of positions.
+         */
+        real distance2() const
+        {
+            GMX_ASSERT(isValid(), "Accessing invalid object");
+            return distance2_;
+        }
 
     private:
         int                     refIndex_;
         int                     testIndex_;
+        real                    distance2_;
 };
 
 /*! \brief
index 998c80371bf6e3e7f89b0f551e0ea7beab346ba9..f6c162ba2c7b5e8918a6160ff860910fc6d24129 100644 (file)
  * \author Teemu Murtola <teemu.murtola@gmail.com>
  * \ingroup module_selection
  */
+#include "gromacs/selection/nbsearch.h"
+
 #include <gtest/gtest.h>
 
 #include <cmath>
 
+#include <algorithm>
 #include <limits>
-#include <set>
 #include <vector>
 
 #include "gromacs/math/vec.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/random/random.h"
-#include "gromacs/selection/nbsearch.h"
 #include "gromacs/utility/smalloc.h"
 
 #include "testutils/testasserts.h"
@@ -70,6 +71,23 @@ namespace
 class NeighborhoodSearchTestData
 {
     public:
+        struct RefPair
+        {
+            RefPair(int refIndex, real distance)
+                : refIndex(refIndex), distance(distance), bFound(false)
+            {
+            }
+
+            bool operator<(const RefPair &other) const
+            {
+                return refIndex < other.refIndex;
+            }
+
+            int                 refIndex;
+            real                distance;
+            bool                bFound;
+        };
+
         struct TestPosition
         {
             TestPosition() : refMinDist(0.0), refNearestPoint(-1)
@@ -82,11 +100,12 @@ class NeighborhoodSearchTestData
                 copy_rvec(x, this->x);
             }
 
-            rvec                x;
-            real                refMinDist;
-            int                 refNearestPoint;
-            std::set<int>       refPairs;
+            rvec                 x;
+            real                 refMinDist;
+            int                  refNearestPoint;
+            std::vector<RefPair> refPairs;
         };
+
         typedef std::vector<TestPosition> TestPositionList;
 
         NeighborhoodSearchTestData(int seed, real cutoff);
@@ -123,7 +142,26 @@ class NeighborhoodSearchTestData
         void generateRandomPosition(rvec x);
         void generateRandomRefPositions(int count);
         void generateRandomTestPositions(int count);
-        void computeReferences(t_pbc *pbc);
+        void computeReferences(t_pbc *pbc)
+        {
+            computeReferencesInternal(pbc, false);
+        }
+        void computeReferencesXY(t_pbc *pbc)
+        {
+            computeReferencesInternal(pbc, true);
+        }
+
+        bool containsPair(int testIndex, const RefPair &pair) const
+        {
+            const std::vector<RefPair>          &refPairs = testPositions_[testIndex].refPairs;
+            std::vector<RefPair>::const_iterator foundRefPair
+                = std::lower_bound(refPairs.begin(), refPairs.end(), pair);
+            if (foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex)
+            {
+                return false;
+            }
+            return true;
+        }
 
         gmx_rng_t                        rng_;
         real                             cutoff_;
@@ -134,9 +172,14 @@ class NeighborhoodSearchTestData
         TestPositionList                 testPositions_;
 
     private:
+        void computeReferencesInternal(t_pbc *pbc, bool bXY);
+
         mutable rvec                    *testPos_;
 };
 
+//! Shorthand for a collection of reference pairs.
+typedef std::vector<NeighborhoodSearchTestData::RefPair> RefPairList;
+
 NeighborhoodSearchTestData::NeighborhoodSearchTestData(int seed, real cutoff)
     : rng_(NULL), cutoff_(cutoff), refPosCount_(0), refPos_(NULL), testPos_(NULL)
 {
@@ -190,7 +233,7 @@ void NeighborhoodSearchTestData::generateRandomTestPositions(int count)
     }
 }
 
-void NeighborhoodSearchTestData::computeReferences(t_pbc *pbc)
+void NeighborhoodSearchTestData::computeReferencesInternal(t_pbc *pbc, bool bXY)
 {
     real cutoff = cutoff_;
     if (cutoff <= 0)
@@ -214,7 +257,11 @@ void NeighborhoodSearchTestData::computeReferences(t_pbc *pbc)
             {
                 rvec_sub(i->x, refPos_[j], dx);
             }
-            const real dist = norm(dx);
+            // TODO: This may not work intuitively for 2D with the third box
+            // vector not parallel to the Z axis, but neither does the actual
+            // neighborhood search.
+            const real dist =
+                !bXY ? norm(dx) : sqrt(sqr(dx[XX]) + sqr(dx[YY]));
             if (dist < i->refMinDist)
             {
                 i->refMinDist      = dist;
@@ -222,7 +269,10 @@ void NeighborhoodSearchTestData::computeReferences(t_pbc *pbc)
             }
             if (dist <= cutoff)
             {
-                i->refPairs.insert(j);
+                RefPair pair(j, dist);
+                GMX_RELEASE_ASSERT(i->refPairs.empty() || i->refPairs.back() < pair,
+                                   "Reference pairs should be generated in sorted order");
+                i->refPairs.push_back(pair);
             }
         }
     }
@@ -285,6 +335,8 @@ void NeighborhoodSearchTest::testNearestPoint(
         {
             EXPECT_EQ(i->refNearestPoint, pair.refIndex());
             EXPECT_EQ(0, pair.testIndex());
+            EXPECT_REAL_EQ_TOL(i->refMinDist, sqrt(pair.distance2()),
+                               gmx::test::ulpTolerance(64));
         }
         else
         {
@@ -298,26 +350,53 @@ void NeighborhoodSearchTest::testPairSearch(
         const NeighborhoodSearchTestData &data)
 {
     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
+    // TODO: Test also searching all the test positions in a single search;
+    // currently the implementation just contains this loop, though, but in
+    // the future that may trigger a different code path.
     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
     {
-        std::set<int> checkSet                         = i->refPairs;
-        gmx::AnalysisNeighborhoodPairSearch pairSearch =
-            search->startPairSearch(i->x);
+        RefPairList                         refPairs = i->refPairs;
+        gmx::AnalysisNeighborhoodPairSearch pairSearch
+            search->startPairSearch(i->x);
         gmx::AnalysisNeighborhoodPair       pair;
         while (pairSearch.findNextPair(&pair))
         {
             EXPECT_EQ(0, pair.testIndex());
-            if (checkSet.erase(pair.refIndex()) == 0)
+            NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(),
+                                                           sqrt(pair.distance2()));
+            RefPairList::iterator               foundRefPair
+                = std::lower_bound(refPairs.begin(), refPairs.end(), searchPair);
+            if (foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex())
             {
-                // TODO: Check whether the same pair was returned more than
-                // once and give a better error message if so.
                 ADD_FAILURE()
                 << "Expected: Position " << pair.refIndex()
                 << " is within cutoff.\n"
                 << "  Actual: It is not.";
             }
+            else if (foundRefPair->bFound)
+            {
+                ADD_FAILURE()
+                << "Expected: Position " << pair.refIndex()
+                << " is returned only once.\n"
+                << "  Actual: It is returned multiple times.";
+            }
+            else
+            {
+                foundRefPair->bFound = true;
+                EXPECT_REAL_EQ_TOL(foundRefPair->distance, searchPair.distance,
+                                   gmx::test::ulpTolerance(64));
+            }
+        }
+        for (RefPairList::const_iterator j = refPairs.begin(); j != refPairs.end(); ++j)
+        {
+            if (!j->bFound)
+            {
+                ADD_FAILURE()
+                << "Expected: All pairs within cutoff will be returned.\n"
+                << "  Actual: Position " << j->refIndex << " is not found.";
+                break;
+            }
         }
-        EXPECT_TRUE(checkSet.empty()) << "Some positions were not returned by the pair search.";
     }
 }
 
@@ -375,6 +454,32 @@ class RandomBoxFullPBCData
         NeighborhoodSearchTestData data_;
 };
 
+class RandomBoxXYFullPBCData
+{
+    public:
+        static const NeighborhoodSearchTestData &get()
+        {
+            static RandomBoxXYFullPBCData singleton;
+            return singleton.data_;
+        }
+
+        RandomBoxXYFullPBCData() : data_(54321, 1.0)
+        {
+            data_.box_[XX][XX] = 10.0;
+            data_.box_[YY][YY] = 5.0;
+            data_.box_[ZZ][ZZ] = 7.0;
+            // TODO: Consider whether manually picking some positions would give better
+            // test coverage.
+            data_.generateRandomRefPositions(1000);
+            data_.generateRandomTestPositions(100);
+            set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
+            data_.computeReferencesXY(&data_.pbc_);
+        }
+
+    private:
+        NeighborhoodSearchTestData data_;
+};
+
 class RandomTriclinicFullPBCData
 {
     public:
@@ -496,6 +601,24 @@ TEST_F(NeighborhoodSearchTest, GridSearch2DPBC)
     testPairSearch(&search, data);
 }
 
+TEST_F(NeighborhoodSearchTest, GridSearchXYBox)
+{
+    const NeighborhoodSearchTestData &data = RandomBoxXYFullPBCData::get();
+
+    nb_.setCutoff(data.cutoff_);
+    nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
+    nb_.setXYMode(true);
+    gmx::AnalysisNeighborhoodSearch search =
+        nb_.initSearch(&data.pbc_, data.refPositions());
+    // Currently, grid searching not supported with XY.
+    //ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
+
+    testIsWithin(&search, data);
+    testMinimumDistance(&search, data);
+    testNearestPoint(&search, data);
+    testPairSearch(&search, data);
+}
+
 TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
 {
     const NeighborhoodSearchTestData &data = TrivialTestData::get();
@@ -516,11 +639,17 @@ TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
     gmx::AnalysisNeighborhoodPair pair;
     pairSearch1.findNextPair(&pair);
     EXPECT_EQ(0, pair.testIndex());
-    EXPECT_TRUE(data.testPositions_[0].refPairs.count(pair.refIndex()) == 1);
+    {
+        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+        EXPECT_TRUE(data.containsPair(0, searchPair));
+    }
 
     pairSearch2.findNextPair(&pair);
     EXPECT_EQ(1, pair.testIndex());
-    EXPECT_TRUE(data.testPositions_[1].refPairs.count(pair.refIndex()) == 1);
+    {
+        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+        EXPECT_TRUE(data.containsPair(1, searchPair));
+    }
 }
 
 TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
@@ -545,7 +674,8 @@ TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
             ++currentIndex;
         }
         EXPECT_EQ(currentIndex, pair.testIndex());
-        EXPECT_TRUE(data.testPositions_[currentIndex].refPairs.count(pair.refIndex()) == 1);
+        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), sqrt(pair.distance2()));
+        EXPECT_TRUE(data.containsPair(currentIndex, searchPair));
         pairSearch.skipRemainingPairsForTestPosition();
         ++currentIndex;
     }