d0f61f30f7d934a480198c574a3c12e382fe3c7f
[alexxy/gromacs.git] / src / gromacs / selection / tests / nbsearch.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
5  * Copyright (c) 2018,2019,2020, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36 /*! \internal \file
37  * \brief
38  * Tests selection neighborhood searching.
39  *
40  * \todo
41  * Increase coverage of these tests for different corner cases: other PBC cases
42  * than full 3D, large cutoffs (larger than half the box size), etc.
43  * At least some of these probably don't work correctly.
44  *
45  * \author Teemu Murtola <teemu.murtola@gmail.com>
46  * \ingroup module_selection
47  */
48 #include "gmxpre.h"
49
50 #include "gromacs/selection/nbsearch.h"
51
52 #include <cmath>
53
54 #include <algorithm>
55 #include <limits>
56 #include <map>
57 #include <numeric>
58 #include <vector>
59
60 #include <gtest/gtest.h>
61
62 #include "gromacs/math/functions.h"
63 #include "gromacs/math/vec.h"
64 #include "gromacs/pbcutil/pbc.h"
65 #include "gromacs/random/threefry.h"
66 #include "gromacs/random/uniformrealdistribution.h"
67 #include "gromacs/topology/block.h"
68 #include "gromacs/utility/listoflists.h"
69 #include "gromacs/utility/smalloc.h"
70 #include "gromacs/utility/stringutil.h"
71
72 #include "testutils/testasserts.h"
73
74
75 namespace
76 {
77
78 /********************************************************************
79  * NeighborhoodSearchTestData
80  */
81
82 class NeighborhoodSearchTestData
83 {
84 public:
85     struct RefPair
86     {
87         RefPair(int refIndex, real distance) :
88             refIndex(refIndex),
89             distance(distance),
90             bFound(false),
91             bExcluded(false),
92             bIndexed(true)
93         {
94         }
95
96         bool operator<(const RefPair& other) const { return refIndex < other.refIndex; }
97
98         int  refIndex;
99         real distance;
100         // The variables below are state variables that are only used
101         // during the actual testing after creating a copy of the reference
102         // pair list, not as part of the reference data.
103         // Simpler to have just a single structure for both purposes.
104         bool bFound;
105         bool bExcluded;
106         bool bIndexed;
107     };
108
109     struct TestPosition
110     {
111         explicit TestPosition(const rvec x) : refMinDist(0.0), refNearestPoint(-1)
112         {
113             copy_rvec(x, this->x);
114         }
115
116         rvec                 x;
117         real                 refMinDist;
118         int                  refNearestPoint;
119         std::vector<RefPair> refPairs;
120     };
121
122     typedef std::vector<TestPosition> TestPositionList;
123
124     NeighborhoodSearchTestData(uint64_t seed, real cutoff);
125
126     gmx::AnalysisNeighborhoodPositions refPositions() const
127     {
128         return gmx::AnalysisNeighborhoodPositions(refPos_);
129     }
130     gmx::AnalysisNeighborhoodPositions testPositions() const
131     {
132         if (testPos_.empty())
133         {
134             testPos_.reserve(testPositions_.size());
135             for (size_t i = 0; i < testPositions_.size(); ++i)
136             {
137                 testPos_.emplace_back(testPositions_[i].x);
138             }
139         }
140         return gmx::AnalysisNeighborhoodPositions(testPos_);
141     }
142     gmx::AnalysisNeighborhoodPositions testPosition(int index) const
143     {
144         return testPositions().selectSingleFromArray(index);
145     }
146
147     void addTestPosition(const rvec x)
148     {
149         GMX_RELEASE_ASSERT(testPos_.empty(), "Cannot add positions after testPositions() call");
150         testPositions_.emplace_back(x);
151     }
152     gmx::RVec               generateRandomPosition();
153     static std::vector<int> generateIndex(int count, uint64_t seed);
154     void                    generateRandomRefPositions(int count);
155     void                    generateRandomTestPositions(int count);
156     void                    useRefPositionsAsTestPositions();
157     void                    computeReferences(t_pbc* pbc) { computeReferencesInternal(pbc, false); }
158     void computeReferencesXY(t_pbc* pbc) { computeReferencesInternal(pbc, true); }
159
160     bool containsPair(int testIndex, const RefPair& pair) const
161     {
162         const std::vector<RefPair>&          refPairs = testPositions_[testIndex].refPairs;
163         std::vector<RefPair>::const_iterator foundRefPair =
164                 std::lower_bound(refPairs.begin(), refPairs.end(), pair);
165         return !(foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex);
166     }
167
168     // Return a tolerance that accounts for the magnitudes of the coordinates
169     // when doing subtractions, so that we set the ULP tolerance relative to the
170     // coordinate values rather than their difference.
171     // i.e., 10.0-9.9999999 will achieve a few ULP accuracy relative
172     // to 10.0, but a much larger error relative to the difference.
173     gmx::test::FloatingPointTolerance relativeTolerance() const
174     {
175         real magnitude = std::max(box_[XX][XX], std::max(box_[YY][YY], box_[ZZ][ZZ]));
176         return gmx::test::relativeToleranceAsUlp(magnitude, 4);
177     }
178
179     gmx::DefaultRandomEngine rng_;
180     real                     cutoff_;
181     matrix                   box_;
182     t_pbc                    pbc_;
183     int                      refPosCount_;
184     std::vector<gmx::RVec>   refPos_;
185     TestPositionList         testPositions_;
186
187 private:
188     void computeReferencesInternal(t_pbc* pbc, bool bXY);
189
190     mutable std::vector<gmx::RVec> testPos_;
191 };
192
193 //! Shorthand for a collection of reference pairs.
194 typedef std::vector<NeighborhoodSearchTestData::RefPair> RefPairList;
195
196 NeighborhoodSearchTestData::NeighborhoodSearchTestData(uint64_t seed, real cutoff) :
197     rng_(seed),
198     cutoff_(cutoff),
199     refPosCount_(0)
200 {
201     clear_mat(box_);
202     set_pbc(&pbc_, PbcType::No, box_);
203 }
204
205 gmx::RVec NeighborhoodSearchTestData::generateRandomPosition()
206 {
207     gmx::UniformRealDistribution<real> dist;
208     rvec                               fx, x;
209     fx[XX] = dist(rng_);
210     fx[YY] = dist(rng_);
211     fx[ZZ] = dist(rng_);
212     mvmul(box_, fx, x);
213     // Add a small displacement to allow positions outside the box
214     x[XX] += 0.2 * dist(rng_) - 0.1;
215     x[YY] += 0.2 * dist(rng_) - 0.1;
216     x[ZZ] += 0.2 * dist(rng_) - 0.1;
217     return x;
218 }
219
220 std::vector<int> NeighborhoodSearchTestData::generateIndex(int count, uint64_t seed)
221 {
222     gmx::DefaultRandomEngine           rngIndex(seed);
223     gmx::UniformRealDistribution<real> dist;
224     std::vector<int>                   result;
225
226     for (int i = 0; i < count; ++i)
227     {
228         if (dist(rngIndex) > 0.5)
229         {
230             result.push_back(i);
231         }
232     }
233     return result;
234 }
235
236 void NeighborhoodSearchTestData::generateRandomRefPositions(int count)
237 {
238     refPosCount_ = count;
239     refPos_.reserve(count);
240     for (int i = 0; i < count; ++i)
241     {
242         refPos_.push_back(generateRandomPosition());
243     }
244 }
245
246 void NeighborhoodSearchTestData::generateRandomTestPositions(int count)
247 {
248     testPositions_.reserve(count);
249     for (int i = 0; i < count; ++i)
250     {
251         addTestPosition(generateRandomPosition());
252     }
253 }
254
255 void NeighborhoodSearchTestData::useRefPositionsAsTestPositions()
256 {
257     testPositions_.reserve(refPosCount_);
258     for (const auto& refPos : refPos_)
259     {
260         addTestPosition(refPos);
261     }
262 }
263
264 void NeighborhoodSearchTestData::computeReferencesInternal(t_pbc* pbc, bool bXY)
265 {
266     real cutoff = cutoff_;
267     if (cutoff <= 0)
268     {
269         cutoff = std::numeric_limits<real>::max();
270     }
271     for (TestPosition& testPos : testPositions_)
272     {
273         testPos.refMinDist      = cutoff;
274         testPos.refNearestPoint = -1;
275         testPos.refPairs.clear();
276         for (int j = 0; j < refPosCount_; ++j)
277         {
278             rvec dx;
279             if (pbc != nullptr)
280             {
281                 pbc_dx(pbc, testPos.x, refPos_[j], dx);
282             }
283             else
284             {
285                 rvec_sub(testPos.x, refPos_[j], dx);
286             }
287             // TODO: This may not work intuitively for 2D with the third box
288             // vector not parallel to the Z axis, but neither does the actual
289             // neighborhood search.
290             const real dist = !bXY ? norm(dx) : std::hypot(dx[XX], dx[YY]);
291             if (dist < testPos.refMinDist)
292             {
293                 testPos.refMinDist      = dist;
294                 testPos.refNearestPoint = j;
295             }
296             if (dist > 0 && dist <= cutoff)
297             {
298                 RefPair pair(j, dist);
299                 GMX_RELEASE_ASSERT(testPos.refPairs.empty() || testPos.refPairs.back() < pair,
300                                    "Reference pairs should be generated in sorted order");
301                 testPos.refPairs.push_back(pair);
302             }
303         }
304     }
305 }
306
307 /********************************************************************
308  * ExclusionsHelper
309  */
310
311 class ExclusionsHelper
312 {
313 public:
314     static void markExcludedPairs(RefPairList* refPairs, int testIndex, const gmx::ListOfLists<int>* excls);
315
316     ExclusionsHelper(int refPosCount, int testPosCount);
317
318     void generateExclusions();
319
320     const gmx::ListOfLists<int>* exclusions() const { return &excls_; }
321
322     gmx::ArrayRef<const int> refPosIds() const
323     {
324         return gmx::makeArrayRef(exclusionIds_).subArray(0, refPosCount_);
325     }
326     gmx::ArrayRef<const int> testPosIds() const
327     {
328         return gmx::makeArrayRef(exclusionIds_).subArray(0, testPosCount_);
329     }
330
331 private:
332     int                   refPosCount_;
333     int                   testPosCount_;
334     std::vector<int>      exclusionIds_;
335     gmx::ListOfLists<int> excls_;
336 };
337
338 // static
339 void ExclusionsHelper::markExcludedPairs(RefPairList* refPairs, int testIndex, const gmx::ListOfLists<int>* excls)
340 {
341     int count = 0;
342     for (const int excludedIndex : (*excls)[testIndex])
343     {
344         NeighborhoodSearchTestData::RefPair searchPair(excludedIndex, 0.0);
345         RefPairList::iterator               excludedRefPair =
346                 std::lower_bound(refPairs->begin(), refPairs->end(), searchPair);
347         if (excludedRefPair != refPairs->end() && excludedRefPair->refIndex == excludedIndex)
348         {
349             excludedRefPair->bFound    = true;
350             excludedRefPair->bExcluded = true;
351             ++count;
352         }
353     }
354 }
355
356 ExclusionsHelper::ExclusionsHelper(int refPosCount, int testPosCount) :
357     refPosCount_(refPosCount),
358     testPosCount_(testPosCount)
359 {
360     // Generate an array of 0, 1, 2, ...
361     // TODO: Make the tests work also with non-trivial exclusion IDs,
362     // and test that.
363     exclusionIds_.resize(std::max(refPosCount, testPosCount), 1);
364     exclusionIds_[0] = 0;
365     std::partial_sum(exclusionIds_.begin(), exclusionIds_.end(), exclusionIds_.begin());
366 }
367
368 void ExclusionsHelper::generateExclusions()
369 {
370     // TODO: Consider a better set of test data, where the density of the
371     // particles would be higher, or where the exclusions would not be random,
372     // to make a higher percentage of the exclusions to actually be within the
373     // cutoff.
374     for (int i = 0; i < testPosCount_; ++i)
375     {
376         excls_.pushBackListOfSize(20);
377         gmx::ArrayRef<int> exclusionsForAtom = excls_.back();
378         for (int j = 0; j < 20; ++j)
379         {
380             exclusionsForAtom[j] = i + j * 3;
381         }
382     }
383 }
384
385 /********************************************************************
386  * NeighborhoodSearchTest
387  */
388
389 class NeighborhoodSearchTest : public ::testing::Test
390 {
391 public:
392     static void testIsWithin(gmx::AnalysisNeighborhoodSearch*  search,
393                              const NeighborhoodSearchTestData& data);
394     static void testMinimumDistance(gmx::AnalysisNeighborhoodSearch*  search,
395                                     const NeighborhoodSearchTestData& data);
396     static void testNearestPoint(gmx::AnalysisNeighborhoodSearch*  search,
397                                  const NeighborhoodSearchTestData& data);
398     static void testPairSearch(gmx::AnalysisNeighborhoodSearch*  search,
399                                const NeighborhoodSearchTestData& data);
400     static void testPairSearchIndexed(gmx::AnalysisNeighborhood*        nb,
401                                       const NeighborhoodSearchTestData& data,
402                                       uint64_t                          seed);
403     static void testPairSearchFull(gmx::AnalysisNeighborhoodSearch*          search,
404                                    const NeighborhoodSearchTestData&         data,
405                                    const gmx::AnalysisNeighborhoodPositions& pos,
406                                    const gmx::ListOfLists<int>*              excls,
407                                    const gmx::ArrayRef<const int>&           refIndices,
408                                    const gmx::ArrayRef<const int>&           testIndices,
409                                    bool                                      selfPairs);
410
411     gmx::AnalysisNeighborhood nb_;
412 };
413
414 void NeighborhoodSearchTest::testIsWithin(gmx::AnalysisNeighborhoodSearch*  search,
415                                           const NeighborhoodSearchTestData& data)
416 {
417     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
418     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
419     {
420         const bool bWithin = (i->refMinDist <= data.cutoff_);
421         EXPECT_EQ(bWithin, search->isWithin(i->x)) << "Distance is " << i->refMinDist;
422     }
423 }
424
425 void NeighborhoodSearchTest::testMinimumDistance(gmx::AnalysisNeighborhoodSearch*  search,
426                                                  const NeighborhoodSearchTestData& data)
427 {
428     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
429
430     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
431     {
432         const real refDist = i->refMinDist;
433         EXPECT_REAL_EQ_TOL(refDist, search->minimumDistance(i->x), data.relativeTolerance());
434     }
435 }
436
437 void NeighborhoodSearchTest::testNearestPoint(gmx::AnalysisNeighborhoodSearch*  search,
438                                               const NeighborhoodSearchTestData& data)
439 {
440     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
441     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
442     {
443         const gmx::AnalysisNeighborhoodPair pair = search->nearestPoint(i->x);
444         if (pair.isValid())
445         {
446             EXPECT_EQ(i->refNearestPoint, pair.refIndex());
447             EXPECT_EQ(0, pair.testIndex());
448             EXPECT_REAL_EQ_TOL(i->refMinDist, std::sqrt(pair.distance2()), data.relativeTolerance());
449         }
450         else
451         {
452             EXPECT_EQ(i->refNearestPoint, -1);
453         }
454     }
455 }
456
457 //! Helper function for formatting test failure messages.
458 std::string formatVector(const rvec x)
459 {
460     return gmx::formatString("[%.3f, %.3f, %.3f]", x[XX], x[YY], x[ZZ]);
461 }
462
463 /*! \brief
464  * Helper function to check that all expected pairs were found.
465  */
466 void checkAllPairsFound(const RefPairList&            refPairs,
467                         const std::vector<gmx::RVec>& refPos,
468                         int                           testPosIndex,
469                         const rvec                    testPos)
470 {
471     // This could be elegantly expressed with Google Mock matchers, but that
472     // has a significant effect on the runtime of the tests...
473     int                         count = 0;
474     RefPairList::const_iterator first;
475     for (RefPairList::const_iterator i = refPairs.begin(); i != refPairs.end(); ++i)
476     {
477         if (!i->bFound)
478         {
479             ++count;
480             first = i;
481         }
482     }
483     if (count > 0)
484     {
485         ADD_FAILURE() << "Some pairs (" << count << "/" << refPairs.size() << ") "
486                       << "within the cutoff were not found. First pair:\n"
487                       << " Ref: " << first->refIndex << " at "
488                       << formatVector(refPos[first->refIndex]) << "\n"
489                       << "Test: " << testPosIndex << " at " << formatVector(testPos) << "\n"
490                       << "Dist: " << first->distance;
491     }
492 }
493
494 void NeighborhoodSearchTest::testPairSearch(gmx::AnalysisNeighborhoodSearch*  search,
495                                             const NeighborhoodSearchTestData& data)
496 {
497     testPairSearchFull(search, data, data.testPositions(), nullptr, {}, {}, false);
498 }
499
500 void NeighborhoodSearchTest::testPairSearchIndexed(gmx::AnalysisNeighborhood*        nb,
501                                                    const NeighborhoodSearchTestData& data,
502                                                    uint64_t                          seed)
503 {
504     std::vector<int> refIndices(data.generateIndex(data.refPos_.size(), seed++));
505     std::vector<int> testIndices(data.generateIndex(data.testPositions_.size(), seed++));
506     gmx::AnalysisNeighborhoodSearch search =
507             nb->initSearch(&data.pbc_, data.refPositions().indexed(refIndices));
508     testPairSearchFull(&search, data, data.testPositions(), nullptr, refIndices, testIndices, false);
509 }
510
511 void NeighborhoodSearchTest::testPairSearchFull(gmx::AnalysisNeighborhoodSearch*          search,
512                                                 const NeighborhoodSearchTestData&         data,
513                                                 const gmx::AnalysisNeighborhoodPositions& pos,
514                                                 const gmx::ListOfLists<int>*              excls,
515                                                 const gmx::ArrayRef<const int>& refIndices,
516                                                 const gmx::ArrayRef<const int>& testIndices,
517                                                 bool                            selfPairs)
518 {
519     std::map<int, RefPairList> refPairs;
520     // TODO: Some parts of this code do not work properly if pos does not
521     // initially contain all the test positions.
522     if (testIndices.empty())
523     {
524         for (size_t i = 0; i < data.testPositions_.size(); ++i)
525         {
526             refPairs[i] = data.testPositions_[i].refPairs;
527         }
528     }
529     else
530     {
531         for (int index : testIndices)
532         {
533             refPairs[index] = data.testPositions_[index].refPairs;
534         }
535     }
536     if (excls != nullptr)
537     {
538         GMX_RELEASE_ASSERT(!selfPairs, "Self-pairs testing not implemented with exclusions");
539         for (auto& entry : refPairs)
540         {
541             const int testIndex = entry.first;
542             ExclusionsHelper::markExcludedPairs(&entry.second, testIndex, excls);
543         }
544     }
545     if (!refIndices.empty())
546     {
547         GMX_RELEASE_ASSERT(!selfPairs, "Self-pairs testing not implemented with indexing");
548         for (auto& entry : refPairs)
549         {
550             for (auto& refPair : entry.second)
551             {
552                 refPair.bIndexed = false;
553             }
554             for (int index : refIndices)
555             {
556                 NeighborhoodSearchTestData::RefPair searchPair(index, 0.0);
557                 auto refPair = std::lower_bound(entry.second.begin(), entry.second.end(), searchPair);
558                 if (refPair != entry.second.end() && refPair->refIndex == index)
559                 {
560                     refPair->bIndexed = true;
561                 }
562             }
563             for (auto& refPair : entry.second)
564             {
565                 if (!refPair.bIndexed)
566                 {
567                     refPair.bFound = true;
568                 }
569             }
570         }
571     }
572
573     gmx::AnalysisNeighborhoodPositions posCopy(pos);
574     if (!testIndices.empty())
575     {
576         posCopy.indexed(testIndices);
577     }
578     gmx::AnalysisNeighborhoodPairSearch pairSearch =
579             selfPairs ? search->startSelfPairSearch() : search->startPairSearch(posCopy);
580     gmx::AnalysisNeighborhoodPair pair;
581     while (pairSearch.findNextPair(&pair))
582     {
583         const int testIndex = (testIndices.empty() ? pair.testIndex() : testIndices[pair.testIndex()]);
584         const int refIndex = (refIndices.empty() ? pair.refIndex() : refIndices[pair.refIndex()]);
585
586         if (refPairs.count(testIndex) == 0)
587         {
588             ADD_FAILURE() << "Expected: No pairs are returned for test position " << testIndex << ".\n"
589                           << "  Actual: Pair with ref " << refIndex << " is returned.";
590             continue;
591         }
592         NeighborhoodSearchTestData::RefPair searchPair(refIndex, std::sqrt(pair.distance2()));
593         const auto                          foundRefPair =
594                 std::lower_bound(refPairs[testIndex].begin(), refPairs[testIndex].end(), searchPair);
595         if (foundRefPair == refPairs[testIndex].end() || foundRefPair->refIndex != refIndex)
596         {
597             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
598                           << ") is not within the cutoff.\n"
599                           << "  Actual: It is returned.";
600         }
601         else if (foundRefPair->bExcluded)
602         {
603             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
604                           << ") is excluded from the search.\n"
605                           << "  Actual: It is returned.";
606         }
607         else if (!foundRefPair->bIndexed)
608         {
609             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
610                           << ") is not part of the indexed set.\n"
611                           << "  Actual: It is returned.";
612         }
613         else if (foundRefPair->bFound)
614         {
615             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
616                           << ") is returned only once.\n"
617                           << "  Actual: It is returned multiple times.";
618             return;
619         }
620         else
621         {
622             foundRefPair->bFound = true;
623
624             EXPECT_REAL_EQ_TOL(foundRefPair->distance, searchPair.distance, data.relativeTolerance())
625                     << "Distance computed by the neighborhood search does not match.";
626             if (selfPairs)
627             {
628                 searchPair              = NeighborhoodSearchTestData::RefPair(testIndex, 0.0);
629                 const auto otherRefPair = std::lower_bound(refPairs[refIndex].begin(),
630                                                            refPairs[refIndex].end(), searchPair);
631                 GMX_RELEASE_ASSERT(otherRefPair != refPairs[refIndex].end(),
632                                    "Precomputed reference data is not symmetric");
633                 otherRefPair->bFound = true;
634             }
635         }
636     }
637
638     for (auto& entry : refPairs)
639     {
640         const int testIndex = entry.first;
641         checkAllPairsFound(entry.second, data.refPos_, testIndex, data.testPositions_[testIndex].x);
642     }
643 }
644
645 /********************************************************************
646  * Test data generation
647  */
648
649 class TrivialTestData
650 {
651 public:
652     static const NeighborhoodSearchTestData& get()
653     {
654         static TrivialTestData singleton;
655         return singleton.data_;
656     }
657
658     TrivialTestData() : data_(12345, 1.0)
659     {
660         // Make the box so small we are virtually guaranteed to have
661         // several neighbors for the five test positions
662         data_.box_[XX][XX] = 3.0;
663         data_.box_[YY][YY] = 3.0;
664         data_.box_[ZZ][ZZ] = 3.0;
665         data_.generateRandomRefPositions(10);
666         data_.generateRandomTestPositions(5);
667         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
668         data_.computeReferences(&data_.pbc_);
669     }
670
671 private:
672     NeighborhoodSearchTestData data_;
673 };
674
675 class TrivialSelfPairsTestData
676 {
677 public:
678     static const NeighborhoodSearchTestData& get()
679     {
680         static TrivialSelfPairsTestData singleton;
681         return singleton.data_;
682     }
683
684     TrivialSelfPairsTestData() : data_(12345, 1.0)
685     {
686         data_.box_[XX][XX] = 3.0;
687         data_.box_[YY][YY] = 3.0;
688         data_.box_[ZZ][ZZ] = 3.0;
689         data_.generateRandomRefPositions(20);
690         data_.useRefPositionsAsTestPositions();
691         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
692         data_.computeReferences(&data_.pbc_);
693     }
694
695 private:
696     NeighborhoodSearchTestData data_;
697 };
698
699 class TrivialNoPBCTestData
700 {
701 public:
702     static const NeighborhoodSearchTestData& get()
703     {
704         static TrivialNoPBCTestData singleton;
705         return singleton.data_;
706     }
707
708     TrivialNoPBCTestData() : data_(12345, 1.0)
709     {
710         data_.generateRandomRefPositions(10);
711         data_.generateRandomTestPositions(5);
712         data_.computeReferences(nullptr);
713     }
714
715 private:
716     NeighborhoodSearchTestData data_;
717 };
718
719 class RandomBoxFullPBCData
720 {
721 public:
722     static const NeighborhoodSearchTestData& get()
723     {
724         static RandomBoxFullPBCData singleton;
725         return singleton.data_;
726     }
727
728     RandomBoxFullPBCData() : data_(12345, 1.0)
729     {
730         data_.box_[XX][XX] = 10.0;
731         data_.box_[YY][YY] = 5.0;
732         data_.box_[ZZ][ZZ] = 7.0;
733         // TODO: Consider whether manually picking some positions would give better
734         // test coverage.
735         data_.generateRandomRefPositions(1000);
736         data_.generateRandomTestPositions(100);
737         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
738         data_.computeReferences(&data_.pbc_);
739     }
740
741 private:
742     NeighborhoodSearchTestData data_;
743 };
744
745 class RandomBoxSelfPairsData
746 {
747 public:
748     static const NeighborhoodSearchTestData& get()
749     {
750         static RandomBoxSelfPairsData singleton;
751         return singleton.data_;
752     }
753
754     RandomBoxSelfPairsData() : data_(12345, 1.0)
755     {
756         data_.box_[XX][XX] = 10.0;
757         data_.box_[YY][YY] = 5.0;
758         data_.box_[ZZ][ZZ] = 7.0;
759         data_.generateRandomRefPositions(1000);
760         data_.useRefPositionsAsTestPositions();
761         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
762         data_.computeReferences(&data_.pbc_);
763     }
764
765 private:
766     NeighborhoodSearchTestData data_;
767 };
768
769 class RandomBoxXYFullPBCData
770 {
771 public:
772     static const NeighborhoodSearchTestData& get()
773     {
774         static RandomBoxXYFullPBCData singleton;
775         return singleton.data_;
776     }
777
778     RandomBoxXYFullPBCData() : data_(54321, 1.0)
779     {
780         data_.box_[XX][XX] = 10.0;
781         data_.box_[YY][YY] = 5.0;
782         data_.box_[ZZ][ZZ] = 7.0;
783         // TODO: Consider whether manually picking some positions would give better
784         // test coverage.
785         data_.generateRandomRefPositions(1000);
786         data_.generateRandomTestPositions(100);
787         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
788         data_.computeReferencesXY(&data_.pbc_);
789     }
790
791 private:
792     NeighborhoodSearchTestData data_;
793 };
794
795 class RandomTriclinicFullPBCData
796 {
797 public:
798     static const NeighborhoodSearchTestData& get()
799     {
800         static RandomTriclinicFullPBCData singleton;
801         return singleton.data_;
802     }
803
804     RandomTriclinicFullPBCData() : data_(12345, 1.0)
805     {
806         data_.box_[XX][XX] = 5.0;
807         data_.box_[YY][XX] = 2.5;
808         data_.box_[YY][YY] = 2.5 * std::sqrt(3.0);
809         data_.box_[ZZ][XX] = 2.5;
810         data_.box_[ZZ][YY] = 2.5 * std::sqrt(1.0 / 3.0);
811         data_.box_[ZZ][ZZ] = 5.0 * std::sqrt(2.0 / 3.0);
812         // TODO: Consider whether manually picking some positions would give better
813         // test coverage.
814         data_.generateRandomRefPositions(1000);
815         data_.generateRandomTestPositions(100);
816         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
817         data_.computeReferences(&data_.pbc_);
818     }
819
820 private:
821     NeighborhoodSearchTestData data_;
822 };
823
824 class RandomBox2DPBCData
825 {
826 public:
827     static const NeighborhoodSearchTestData& get()
828     {
829         static RandomBox2DPBCData singleton;
830         return singleton.data_;
831     }
832
833     RandomBox2DPBCData() : data_(12345, 1.0)
834     {
835         data_.box_[XX][XX] = 10.0;
836         data_.box_[YY][YY] = 7.0;
837         data_.box_[ZZ][ZZ] = 5.0;
838         // TODO: Consider whether manually picking some positions would give better
839         // test coverage.
840         data_.generateRandomRefPositions(1000);
841         data_.generateRandomTestPositions(100);
842         set_pbc(&data_.pbc_, PbcType::XY, data_.box_);
843         data_.computeReferences(&data_.pbc_);
844     }
845
846 private:
847     NeighborhoodSearchTestData data_;
848 };
849
850 class RandomBoxNoPBCData
851 {
852 public:
853     static const NeighborhoodSearchTestData& get()
854     {
855         static RandomBoxNoPBCData singleton;
856         return singleton.data_;
857     }
858
859     RandomBoxNoPBCData() : data_(12345, 1.0)
860     {
861         data_.box_[XX][XX] = 10.0;
862         data_.box_[YY][YY] = 5.0;
863         data_.box_[ZZ][ZZ] = 7.0;
864         // TODO: Consider whether manually picking some positions would give better
865         // test coverage.
866         data_.generateRandomRefPositions(1000);
867         data_.generateRandomTestPositions(100);
868         set_pbc(&data_.pbc_, PbcType::No, data_.box_);
869         data_.computeReferences(nullptr);
870     }
871
872 private:
873     NeighborhoodSearchTestData data_;
874 };
875
876 /********************************************************************
877  * Actual tests
878  */
879
880 TEST_F(NeighborhoodSearchTest, SimpleSearch)
881 {
882     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
883
884     nb_.setCutoff(data.cutoff_);
885     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
886     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
887     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
888
889     testIsWithin(&search, data);
890     testMinimumDistance(&search, data);
891     testNearestPoint(&search, data);
892     testPairSearch(&search, data);
893
894     search.reset();
895     testPairSearchIndexed(&nb_, data, 123);
896 }
897
898 TEST_F(NeighborhoodSearchTest, SimpleSearchXY)
899 {
900     const NeighborhoodSearchTestData& data = RandomBoxXYFullPBCData::get();
901
902     nb_.setCutoff(data.cutoff_);
903     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
904     nb_.setXYMode(true);
905     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
906     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
907
908     testIsWithin(&search, data);
909     testMinimumDistance(&search, data);
910     testNearestPoint(&search, data);
911     testPairSearch(&search, data);
912 }
913
914 TEST_F(NeighborhoodSearchTest, GridSearchBox)
915 {
916     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
917
918     nb_.setCutoff(data.cutoff_);
919     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
920     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
921     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
922
923     testIsWithin(&search, data);
924     testMinimumDistance(&search, data);
925     testNearestPoint(&search, data);
926     testPairSearch(&search, data);
927
928     search.reset();
929     testPairSearchIndexed(&nb_, data, 456);
930 }
931
932 TEST_F(NeighborhoodSearchTest, GridSearchTriclinic)
933 {
934     const NeighborhoodSearchTestData& data = RandomTriclinicFullPBCData::get();
935
936     nb_.setCutoff(data.cutoff_);
937     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
938     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
939     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
940
941     testPairSearch(&search, data);
942 }
943
944 TEST_F(NeighborhoodSearchTest, GridSearch2DPBC)
945 {
946     const NeighborhoodSearchTestData& data = RandomBox2DPBCData::get();
947
948     nb_.setCutoff(data.cutoff_);
949     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
950     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
951     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
952
953     testIsWithin(&search, data);
954     testMinimumDistance(&search, data);
955     testNearestPoint(&search, data);
956     testPairSearch(&search, data);
957 }
958
959 TEST_F(NeighborhoodSearchTest, GridSearchNoPBC)
960 {
961     const NeighborhoodSearchTestData& data = RandomBoxNoPBCData::get();
962
963     nb_.setCutoff(data.cutoff_);
964     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
965     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
966     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
967
968     testPairSearch(&search, data);
969 }
970
971 TEST_F(NeighborhoodSearchTest, GridSearchXYBox)
972 {
973     const NeighborhoodSearchTestData& data = RandomBoxXYFullPBCData::get();
974
975     nb_.setCutoff(data.cutoff_);
976     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
977     nb_.setXYMode(true);
978     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
979     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
980
981     testIsWithin(&search, data);
982     testMinimumDistance(&search, data);
983     testNearestPoint(&search, data);
984     testPairSearch(&search, data);
985 }
986
987 TEST_F(NeighborhoodSearchTest, SimpleSelfPairsSearch)
988 {
989     const NeighborhoodSearchTestData& data = TrivialSelfPairsTestData::get();
990
991     nb_.setCutoff(data.cutoff_);
992     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
993     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
994     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
995
996     testPairSearchFull(&search, data, data.testPositions(), nullptr, {}, {}, true);
997 }
998
999 TEST_F(NeighborhoodSearchTest, GridSelfPairsSearch)
1000 {
1001     const NeighborhoodSearchTestData& data = RandomBoxSelfPairsData::get();
1002
1003     nb_.setCutoff(data.cutoff_);
1004     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
1005     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
1006     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
1007
1008     testPairSearchFull(&search, data, data.testPositions(), nullptr, {}, {}, true);
1009 }
1010
1011 TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
1012 {
1013     const NeighborhoodSearchTestData& data = TrivialTestData::get();
1014
1015     nb_.setCutoff(data.cutoff_);
1016     gmx::AnalysisNeighborhoodSearch search1 = nb_.initSearch(&data.pbc_, data.refPositions());
1017     gmx::AnalysisNeighborhoodSearch search2 = nb_.initSearch(&data.pbc_, data.refPositions());
1018
1019     // These checks are fragile, and unfortunately depend on the random
1020     // engine used to create the test positions. There is no particular reason
1021     // why exactly particles 0 & 2 should have neighbors, but in this case they do.
1022     gmx::AnalysisNeighborhoodPairSearch pairSearch1 = search1.startPairSearch(data.testPosition(0));
1023     gmx::AnalysisNeighborhoodPairSearch pairSearch2 = search1.startPairSearch(data.testPosition(2));
1024
1025     testPairSearch(&search2, data);
1026
1027     gmx::AnalysisNeighborhoodPair pair;
1028     ASSERT_TRUE(pairSearch1.findNextPair(&pair))
1029             << "Test data did not contain any pairs for position 0 (problem in the test).";
1030     EXPECT_EQ(0, pair.testIndex());
1031     {
1032         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1033         EXPECT_TRUE(data.containsPair(0, searchPair));
1034     }
1035
1036     ASSERT_TRUE(pairSearch2.findNextPair(&pair))
1037             << "Test data did not contain any pairs for position 2 (problem in the test).";
1038     EXPECT_EQ(2, pair.testIndex());
1039     {
1040         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1041         EXPECT_TRUE(data.containsPair(2, searchPair));
1042     }
1043 }
1044
1045 TEST_F(NeighborhoodSearchTest, HandlesNoPBC)
1046 {
1047     const NeighborhoodSearchTestData& data = TrivialNoPBCTestData::get();
1048
1049     nb_.setCutoff(data.cutoff_);
1050     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
1051     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1052
1053     testIsWithin(&search, data);
1054     testMinimumDistance(&search, data);
1055     testNearestPoint(&search, data);
1056     testPairSearch(&search, data);
1057 }
1058
1059 TEST_F(NeighborhoodSearchTest, HandlesNullPBC)
1060 {
1061     const NeighborhoodSearchTestData& data = TrivialNoPBCTestData::get();
1062
1063     nb_.setCutoff(data.cutoff_);
1064     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(nullptr, data.refPositions());
1065     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1066
1067     testIsWithin(&search, data);
1068     testMinimumDistance(&search, data);
1069     testNearestPoint(&search, data);
1070     testPairSearch(&search, data);
1071 }
1072
1073 TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
1074 {
1075     const NeighborhoodSearchTestData& data = TrivialTestData::get();
1076
1077     nb_.setCutoff(data.cutoff_);
1078     gmx::AnalysisNeighborhoodSearch     search = nb_.initSearch(&data.pbc_, data.refPositions());
1079     gmx::AnalysisNeighborhoodPairSearch pairSearch = search.startPairSearch(data.testPositions());
1080     gmx::AnalysisNeighborhoodPair       pair;
1081     // TODO: This test needs to be adjusted if the grid search gets optimized
1082     // to loop over the test positions in cell order (first, the ordering
1083     // assumption here breaks, and second, it then needs to be tested
1084     // separately for simple and grid searches).
1085     int currentIndex = 0;
1086     while (pairSearch.findNextPair(&pair))
1087     {
1088         while (currentIndex < pair.testIndex())
1089         {
1090             ++currentIndex;
1091         }
1092         EXPECT_EQ(currentIndex, pair.testIndex());
1093         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1094         EXPECT_TRUE(data.containsPair(currentIndex, searchPair));
1095         pairSearch.skipRemainingPairsForTestPosition();
1096         ++currentIndex;
1097     }
1098 }
1099
1100 TEST_F(NeighborhoodSearchTest, SimpleSearchExclusions)
1101 {
1102     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
1103
1104     ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
1105     helper.generateExclusions();
1106
1107     nb_.setCutoff(data.cutoff_);
1108     nb_.setTopologyExclusions(helper.exclusions());
1109     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
1110     gmx::AnalysisNeighborhoodSearch search =
1111             nb_.initSearch(&data.pbc_, data.refPositions().exclusionIds(helper.refPosIds()));
1112     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1113
1114     testPairSearchFull(&search, data, data.testPositions().exclusionIds(helper.testPosIds()),
1115                        helper.exclusions(), {}, {}, false);
1116 }
1117
1118 TEST_F(NeighborhoodSearchTest, GridSearchExclusions)
1119 {
1120     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
1121
1122     ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
1123     helper.generateExclusions();
1124
1125     nb_.setCutoff(data.cutoff_);
1126     nb_.setTopologyExclusions(helper.exclusions());
1127     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
1128     gmx::AnalysisNeighborhoodSearch search =
1129             nb_.initSearch(&data.pbc_, data.refPositions().exclusionIds(helper.refPosIds()));
1130     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
1131
1132     testPairSearchFull(&search, data, data.testPositions().exclusionIds(helper.testPosIds()),
1133                        helper.exclusions(), {}, {}, false);
1134 }
1135
1136 } // namespace