Apply clang-format-11
[alexxy/gromacs.git] / src / gromacs / selection / tests / nbsearch.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
5  * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36 /*! \internal \file
37  * \brief
38  * Tests selection neighborhood searching.
39  *
40  * \todo
41  * Increase coverage of these tests for different corner cases: other PBC cases
42  * than full 3D, large cutoffs (larger than half the box size), etc.
43  * At least some of these probably don't work correctly.
44  *
45  * \author Teemu Murtola <teemu.murtola@gmail.com>
46  * \ingroup module_selection
47  */
48 #include "gmxpre.h"
49
50 #include "gromacs/selection/nbsearch.h"
51
52 #include <cmath>
53
54 #include <algorithm>
55 #include <limits>
56 #include <map>
57 #include <numeric>
58 #include <vector>
59
60 #include <gtest/gtest.h>
61
62 #include "gromacs/math/functions.h"
63 #include "gromacs/math/vec.h"
64 #include "gromacs/pbcutil/pbc.h"
65 #include "gromacs/random/threefry.h"
66 #include "gromacs/random/uniformrealdistribution.h"
67 #include "gromacs/topology/block.h"
68 #include "gromacs/utility/listoflists.h"
69 #include "gromacs/utility/smalloc.h"
70 #include "gromacs/utility/stringutil.h"
71
72 #include "testutils/testasserts.h"
73
74
75 namespace
76 {
77
78 /********************************************************************
79  * NeighborhoodSearchTestData
80  */
81
82 class NeighborhoodSearchTestData
83 {
84 public:
85     struct RefPair
86     {
87         RefPair(int refIndex, real distance) :
88             refIndex(refIndex), distance(distance), bFound(false), bExcluded(false), bIndexed(true)
89         {
90         }
91
92         bool operator<(const RefPair& other) const { return refIndex < other.refIndex; }
93
94         int  refIndex;
95         real distance;
96         // The variables below are state variables that are only used
97         // during the actual testing after creating a copy of the reference
98         // pair list, not as part of the reference data.
99         // Simpler to have just a single structure for both purposes.
100         bool bFound;
101         bool bExcluded;
102         bool bIndexed;
103     };
104
105     struct TestPosition
106     {
107         explicit TestPosition(const rvec x) : refMinDist(0.0), refNearestPoint(-1)
108         {
109             copy_rvec(x, this->x);
110         }
111
112         rvec                 x;
113         real                 refMinDist;
114         int                  refNearestPoint;
115         std::vector<RefPair> refPairs;
116     };
117
118     typedef std::vector<TestPosition> TestPositionList;
119
120     NeighborhoodSearchTestData(uint64_t seed, real cutoff);
121
122     gmx::AnalysisNeighborhoodPositions refPositions() const
123     {
124         return gmx::AnalysisNeighborhoodPositions(refPos_);
125     }
126     gmx::AnalysisNeighborhoodPositions testPositions() const
127     {
128         if (testPos_.empty())
129         {
130             testPos_.reserve(testPositions_.size());
131             for (size_t i = 0; i < testPositions_.size(); ++i)
132             {
133                 testPos_.emplace_back(testPositions_[i].x);
134             }
135         }
136         return gmx::AnalysisNeighborhoodPositions(testPos_);
137     }
138     gmx::AnalysisNeighborhoodPositions testPosition(int index) const
139     {
140         return testPositions().selectSingleFromArray(index);
141     }
142
143     void addTestPosition(const rvec x)
144     {
145         GMX_RELEASE_ASSERT(testPos_.empty(), "Cannot add positions after testPositions() call");
146         testPositions_.emplace_back(x);
147     }
148     gmx::RVec               generateRandomPosition();
149     static std::vector<int> generateIndex(int count, uint64_t seed);
150     void                    generateRandomRefPositions(int count);
151     void                    generateRandomTestPositions(int count);
152     void                    useRefPositionsAsTestPositions();
153     void                    computeReferences(t_pbc* pbc) { computeReferencesInternal(pbc, false); }
154     void computeReferencesXY(t_pbc* pbc) { computeReferencesInternal(pbc, true); }
155
156     bool containsPair(int testIndex, const RefPair& pair) const
157     {
158         const std::vector<RefPair>&          refPairs = testPositions_[testIndex].refPairs;
159         std::vector<RefPair>::const_iterator foundRefPair =
160                 std::lower_bound(refPairs.begin(), refPairs.end(), pair);
161         return !(foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex);
162     }
163
164     // Return a tolerance that accounts for the magnitudes of the coordinates
165     // when doing subtractions, so that we set the ULP tolerance relative to the
166     // coordinate values rather than their difference.
167     // i.e., 10.0-9.9999999 will achieve a few ULP accuracy relative
168     // to 10.0, but a much larger error relative to the difference.
169     gmx::test::FloatingPointTolerance relativeTolerance() const
170     {
171         real magnitude = std::max(box_[XX][XX], std::max(box_[YY][YY], box_[ZZ][ZZ]));
172         return gmx::test::relativeToleranceAsUlp(magnitude, 4);
173     }
174
175     gmx::DefaultRandomEngine rng_;
176     real                     cutoff_;
177     matrix                   box_;
178     t_pbc                    pbc_;
179     int                      refPosCount_;
180     std::vector<gmx::RVec>   refPos_;
181     TestPositionList         testPositions_;
182
183 private:
184     void computeReferencesInternal(t_pbc* pbc, bool bXY);
185
186     mutable std::vector<gmx::RVec> testPos_;
187 };
188
189 //! Shorthand for a collection of reference pairs.
190 typedef std::vector<NeighborhoodSearchTestData::RefPair> RefPairList;
191
192 NeighborhoodSearchTestData::NeighborhoodSearchTestData(uint64_t seed, real cutoff) :
193     rng_(seed), cutoff_(cutoff), refPosCount_(0)
194 {
195     clear_mat(box_);
196     set_pbc(&pbc_, PbcType::No, box_);
197 }
198
199 gmx::RVec NeighborhoodSearchTestData::generateRandomPosition()
200 {
201     gmx::UniformRealDistribution<real> dist;
202     rvec                               fx, x;
203     fx[XX] = dist(rng_);
204     fx[YY] = dist(rng_);
205     fx[ZZ] = dist(rng_);
206     mvmul(box_, fx, x);
207     // Add a small displacement to allow positions outside the box
208     x[XX] += 0.2 * dist(rng_) - 0.1;
209     x[YY] += 0.2 * dist(rng_) - 0.1;
210     x[ZZ] += 0.2 * dist(rng_) - 0.1;
211     return x;
212 }
213
214 std::vector<int> NeighborhoodSearchTestData::generateIndex(int count, uint64_t seed)
215 {
216     gmx::DefaultRandomEngine           rngIndex(seed);
217     gmx::UniformRealDistribution<real> dist;
218     std::vector<int>                   result;
219
220     for (int i = 0; i < count; ++i)
221     {
222         if (dist(rngIndex) > 0.5)
223         {
224             result.push_back(i);
225         }
226     }
227     return result;
228 }
229
230 void NeighborhoodSearchTestData::generateRandomRefPositions(int count)
231 {
232     refPosCount_ = count;
233     refPos_.reserve(count);
234     for (int i = 0; i < count; ++i)
235     {
236         refPos_.push_back(generateRandomPosition());
237     }
238 }
239
240 void NeighborhoodSearchTestData::generateRandomTestPositions(int count)
241 {
242     testPositions_.reserve(count);
243     for (int i = 0; i < count; ++i)
244     {
245         addTestPosition(generateRandomPosition());
246     }
247 }
248
249 void NeighborhoodSearchTestData::useRefPositionsAsTestPositions()
250 {
251     testPositions_.reserve(refPosCount_);
252     for (const auto& refPos : refPos_)
253     {
254         addTestPosition(refPos);
255     }
256 }
257
258 void NeighborhoodSearchTestData::computeReferencesInternal(t_pbc* pbc, bool bXY)
259 {
260     real cutoff = cutoff_;
261     if (cutoff <= 0)
262     {
263         cutoff = std::numeric_limits<real>::max();
264     }
265     for (TestPosition& testPos : testPositions_)
266     {
267         testPos.refMinDist      = cutoff;
268         testPos.refNearestPoint = -1;
269         testPos.refPairs.clear();
270         for (int j = 0; j < refPosCount_; ++j)
271         {
272             rvec dx;
273             if (pbc != nullptr)
274             {
275                 pbc_dx(pbc, testPos.x, refPos_[j], dx);
276             }
277             else
278             {
279                 rvec_sub(testPos.x, refPos_[j], dx);
280             }
281             // TODO: This may not work intuitively for 2D with the third box
282             // vector not parallel to the Z axis, but neither does the actual
283             // neighborhood search.
284             const real dist = !bXY ? norm(dx) : std::hypot(dx[XX], dx[YY]);
285             if (dist < testPos.refMinDist)
286             {
287                 testPos.refMinDist      = dist;
288                 testPos.refNearestPoint = j;
289             }
290             if (dist > 0 && dist <= cutoff)
291             {
292                 RefPair pair(j, dist);
293                 GMX_RELEASE_ASSERT(testPos.refPairs.empty() || testPos.refPairs.back() < pair,
294                                    "Reference pairs should be generated in sorted order");
295                 testPos.refPairs.push_back(pair);
296             }
297         }
298     }
299 }
300
301 /********************************************************************
302  * ExclusionsHelper
303  */
304
305 class ExclusionsHelper
306 {
307 public:
308     static void markExcludedPairs(RefPairList* refPairs, int testIndex, const gmx::ListOfLists<int>* excls);
309
310     ExclusionsHelper(int refPosCount, int testPosCount);
311
312     void generateExclusions();
313
314     const gmx::ListOfLists<int>* exclusions() const { return &excls_; }
315
316     gmx::ArrayRef<const int> refPosIds() const
317     {
318         return gmx::makeArrayRef(exclusionIds_).subArray(0, refPosCount_);
319     }
320     gmx::ArrayRef<const int> testPosIds() const
321     {
322         return gmx::makeArrayRef(exclusionIds_).subArray(0, testPosCount_);
323     }
324
325 private:
326     int                   refPosCount_;
327     int                   testPosCount_;
328     std::vector<int>      exclusionIds_;
329     gmx::ListOfLists<int> excls_;
330 };
331
332 // static
333 void ExclusionsHelper::markExcludedPairs(RefPairList* refPairs, int testIndex, const gmx::ListOfLists<int>* excls)
334 {
335     int count = 0;
336     for (const int excludedIndex : (*excls)[testIndex])
337     {
338         NeighborhoodSearchTestData::RefPair searchPair(excludedIndex, 0.0);
339         RefPairList::iterator               excludedRefPair =
340                 std::lower_bound(refPairs->begin(), refPairs->end(), searchPair);
341         if (excludedRefPair != refPairs->end() && excludedRefPair->refIndex == excludedIndex)
342         {
343             excludedRefPair->bFound    = true;
344             excludedRefPair->bExcluded = true;
345             ++count;
346         }
347     }
348 }
349
350 ExclusionsHelper::ExclusionsHelper(int refPosCount, int testPosCount) :
351     refPosCount_(refPosCount), testPosCount_(testPosCount)
352 {
353     // Generate an array of 0, 1, 2, ...
354     // TODO: Make the tests work also with non-trivial exclusion IDs,
355     // and test that.
356     exclusionIds_.resize(std::max(refPosCount, testPosCount), 1);
357     exclusionIds_[0] = 0;
358     std::partial_sum(exclusionIds_.begin(), exclusionIds_.end(), exclusionIds_.begin());
359 }
360
361 void ExclusionsHelper::generateExclusions()
362 {
363     // TODO: Consider a better set of test data, where the density of the
364     // particles would be higher, or where the exclusions would not be random,
365     // to make a higher percentage of the exclusions to actually be within the
366     // cutoff.
367     for (int i = 0; i < testPosCount_; ++i)
368     {
369         excls_.pushBackListOfSize(20);
370         gmx::ArrayRef<int> exclusionsForAtom = excls_.back();
371         for (int j = 0; j < 20; ++j)
372         {
373             exclusionsForAtom[j] = i + j * 3;
374         }
375     }
376 }
377
378 /********************************************************************
379  * NeighborhoodSearchTest
380  */
381
382 class NeighborhoodSearchTest : public ::testing::Test
383 {
384 public:
385     static void testIsWithin(gmx::AnalysisNeighborhoodSearch*  search,
386                              const NeighborhoodSearchTestData& data);
387     static void testMinimumDistance(gmx::AnalysisNeighborhoodSearch*  search,
388                                     const NeighborhoodSearchTestData& data);
389     static void testNearestPoint(gmx::AnalysisNeighborhoodSearch*  search,
390                                  const NeighborhoodSearchTestData& data);
391     static void testPairSearch(gmx::AnalysisNeighborhoodSearch*  search,
392                                const NeighborhoodSearchTestData& data);
393     static void testPairSearchIndexed(gmx::AnalysisNeighborhood*        nb,
394                                       const NeighborhoodSearchTestData& data,
395                                       uint64_t                          seed);
396     static void testPairSearchFull(gmx::AnalysisNeighborhoodSearch*          search,
397                                    const NeighborhoodSearchTestData&         data,
398                                    const gmx::AnalysisNeighborhoodPositions& pos,
399                                    const gmx::ListOfLists<int>*              excls,
400                                    const gmx::ArrayRef<const int>&           refIndices,
401                                    const gmx::ArrayRef<const int>&           testIndices,
402                                    bool                                      selfPairs);
403
404     gmx::AnalysisNeighborhood nb_;
405 };
406
407 void NeighborhoodSearchTest::testIsWithin(gmx::AnalysisNeighborhoodSearch*  search,
408                                           const NeighborhoodSearchTestData& data)
409 {
410     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
411     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
412     {
413         const bool bWithin = (i->refMinDist <= data.cutoff_);
414         EXPECT_EQ(bWithin, search->isWithin(i->x)) << "Distance is " << i->refMinDist;
415     }
416 }
417
418 void NeighborhoodSearchTest::testMinimumDistance(gmx::AnalysisNeighborhoodSearch*  search,
419                                                  const NeighborhoodSearchTestData& data)
420 {
421     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
422
423     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
424     {
425         const real refDist = i->refMinDist;
426         EXPECT_REAL_EQ_TOL(refDist, search->minimumDistance(i->x), data.relativeTolerance());
427     }
428 }
429
430 void NeighborhoodSearchTest::testNearestPoint(gmx::AnalysisNeighborhoodSearch*  search,
431                                               const NeighborhoodSearchTestData& data)
432 {
433     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
434     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
435     {
436         const gmx::AnalysisNeighborhoodPair pair = search->nearestPoint(i->x);
437         if (pair.isValid())
438         {
439             EXPECT_EQ(i->refNearestPoint, pair.refIndex());
440             EXPECT_EQ(0, pair.testIndex());
441             EXPECT_REAL_EQ_TOL(i->refMinDist, std::sqrt(pair.distance2()), data.relativeTolerance());
442         }
443         else
444         {
445             EXPECT_EQ(i->refNearestPoint, -1);
446         }
447     }
448 }
449
450 //! Helper function for formatting test failure messages.
451 std::string formatVector(const rvec x)
452 {
453     return gmx::formatString("[%.3f, %.3f, %.3f]", x[XX], x[YY], x[ZZ]);
454 }
455
456 /*! \brief
457  * Helper function to check that all expected pairs were found.
458  */
459 void checkAllPairsFound(const RefPairList&            refPairs,
460                         const std::vector<gmx::RVec>& refPos,
461                         int                           testPosIndex,
462                         const rvec                    testPos)
463 {
464     // This could be elegantly expressed with Google Mock matchers, but that
465     // has a significant effect on the runtime of the tests...
466     int                         count = 0;
467     RefPairList::const_iterator first;
468     for (RefPairList::const_iterator i = refPairs.begin(); i != refPairs.end(); ++i)
469     {
470         if (!i->bFound)
471         {
472             ++count;
473             first = i;
474         }
475     }
476     if (count > 0)
477     {
478         ADD_FAILURE() << "Some pairs (" << count << "/" << refPairs.size() << ") "
479                       << "within the cutoff were not found. First pair:\n"
480                       << " Ref: " << first->refIndex << " at "
481                       << formatVector(refPos[first->refIndex]) << "\n"
482                       << "Test: " << testPosIndex << " at " << formatVector(testPos) << "\n"
483                       << "Dist: " << first->distance;
484     }
485 }
486
487 void NeighborhoodSearchTest::testPairSearch(gmx::AnalysisNeighborhoodSearch*  search,
488                                             const NeighborhoodSearchTestData& data)
489 {
490     testPairSearchFull(search, data, data.testPositions(), nullptr, {}, {}, false);
491 }
492
493 void NeighborhoodSearchTest::testPairSearchIndexed(gmx::AnalysisNeighborhood*        nb,
494                                                    const NeighborhoodSearchTestData& data,
495                                                    uint64_t                          seed)
496 {
497     std::vector<int> refIndices(data.generateIndex(data.refPos_.size(), seed++));
498     std::vector<int> testIndices(data.generateIndex(data.testPositions_.size(), seed++));
499     gmx::AnalysisNeighborhoodSearch search =
500             nb->initSearch(&data.pbc_, data.refPositions().indexed(refIndices));
501     testPairSearchFull(&search, data, data.testPositions(), nullptr, refIndices, testIndices, false);
502 }
503
504 void NeighborhoodSearchTest::testPairSearchFull(gmx::AnalysisNeighborhoodSearch*          search,
505                                                 const NeighborhoodSearchTestData&         data,
506                                                 const gmx::AnalysisNeighborhoodPositions& pos,
507                                                 const gmx::ListOfLists<int>*              excls,
508                                                 const gmx::ArrayRef<const int>& refIndices,
509                                                 const gmx::ArrayRef<const int>& testIndices,
510                                                 bool                            selfPairs)
511 {
512     std::map<int, RefPairList> refPairs;
513     // TODO: Some parts of this code do not work properly if pos does not
514     // initially contain all the test positions.
515     if (testIndices.empty())
516     {
517         for (size_t i = 0; i < data.testPositions_.size(); ++i)
518         {
519             refPairs[i] = data.testPositions_[i].refPairs;
520         }
521     }
522     else
523     {
524         for (int index : testIndices)
525         {
526             refPairs[index] = data.testPositions_[index].refPairs;
527         }
528     }
529     if (excls != nullptr)
530     {
531         GMX_RELEASE_ASSERT(!selfPairs, "Self-pairs testing not implemented with exclusions");
532         for (auto& entry : refPairs)
533         {
534             const int testIndex = entry.first;
535             ExclusionsHelper::markExcludedPairs(&entry.second, testIndex, excls);
536         }
537     }
538     if (!refIndices.empty())
539     {
540         GMX_RELEASE_ASSERT(!selfPairs, "Self-pairs testing not implemented with indexing");
541         for (auto& entry : refPairs)
542         {
543             for (auto& refPair : entry.second)
544             {
545                 refPair.bIndexed = false;
546             }
547             for (int index : refIndices)
548             {
549                 NeighborhoodSearchTestData::RefPair searchPair(index, 0.0);
550                 auto refPair = std::lower_bound(entry.second.begin(), entry.second.end(), searchPair);
551                 if (refPair != entry.second.end() && refPair->refIndex == index)
552                 {
553                     refPair->bIndexed = true;
554                 }
555             }
556             for (auto& refPair : entry.second)
557             {
558                 if (!refPair.bIndexed)
559                 {
560                     refPair.bFound = true;
561                 }
562             }
563         }
564     }
565
566     gmx::AnalysisNeighborhoodPositions posCopy(pos);
567     if (!testIndices.empty())
568     {
569         posCopy.indexed(testIndices);
570     }
571     gmx::AnalysisNeighborhoodPairSearch pairSearch =
572             selfPairs ? search->startSelfPairSearch() : search->startPairSearch(posCopy);
573     gmx::AnalysisNeighborhoodPair pair;
574     while (pairSearch.findNextPair(&pair))
575     {
576         const int testIndex = (testIndices.empty() ? pair.testIndex() : testIndices[pair.testIndex()]);
577         const int refIndex = (refIndices.empty() ? pair.refIndex() : refIndices[pair.refIndex()]);
578
579         if (refPairs.count(testIndex) == 0)
580         {
581             ADD_FAILURE() << "Expected: No pairs are returned for test position " << testIndex << ".\n"
582                           << "  Actual: Pair with ref " << refIndex << " is returned.";
583             continue;
584         }
585         NeighborhoodSearchTestData::RefPair searchPair(refIndex, std::sqrt(pair.distance2()));
586         const auto                          foundRefPair =
587                 std::lower_bound(refPairs[testIndex].begin(), refPairs[testIndex].end(), searchPair);
588         if (foundRefPair == refPairs[testIndex].end() || foundRefPair->refIndex != refIndex)
589         {
590             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
591                           << ") is not within the cutoff.\n"
592                           << "  Actual: It is returned.";
593         }
594         else if (foundRefPair->bExcluded)
595         {
596             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
597                           << ") is excluded from the search.\n"
598                           << "  Actual: It is returned.";
599         }
600         else if (!foundRefPair->bIndexed)
601         {
602             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
603                           << ") is not part of the indexed set.\n"
604                           << "  Actual: It is returned.";
605         }
606         else if (foundRefPair->bFound)
607         {
608             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
609                           << ") is returned only once.\n"
610                           << "  Actual: It is returned multiple times.";
611             return;
612         }
613         else
614         {
615             foundRefPair->bFound = true;
616
617             EXPECT_REAL_EQ_TOL(foundRefPair->distance, searchPair.distance, data.relativeTolerance())
618                     << "Distance computed by the neighborhood search does not match.";
619             if (selfPairs)
620             {
621                 searchPair              = NeighborhoodSearchTestData::RefPair(testIndex, 0.0);
622                 const auto otherRefPair = std::lower_bound(
623                         refPairs[refIndex].begin(), refPairs[refIndex].end(), searchPair);
624                 GMX_RELEASE_ASSERT(otherRefPair != refPairs[refIndex].end(),
625                                    "Precomputed reference data is not symmetric");
626                 otherRefPair->bFound = true;
627             }
628         }
629     }
630
631     for (auto& entry : refPairs)
632     {
633         const int testIndex = entry.first;
634         checkAllPairsFound(entry.second, data.refPos_, testIndex, data.testPositions_[testIndex].x);
635     }
636 }
637
638 /********************************************************************
639  * Test data generation
640  */
641
642 class TrivialTestData
643 {
644 public:
645     static const NeighborhoodSearchTestData& get()
646     {
647         static TrivialTestData singleton;
648         return singleton.data_;
649     }
650
651     TrivialTestData() : data_(12345, 1.0)
652     {
653         // Make the box so small we are virtually guaranteed to have
654         // several neighbors for the five test positions
655         data_.box_[XX][XX] = 3.0;
656         data_.box_[YY][YY] = 3.0;
657         data_.box_[ZZ][ZZ] = 3.0;
658         data_.generateRandomRefPositions(10);
659         data_.generateRandomTestPositions(5);
660         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
661         data_.computeReferences(&data_.pbc_);
662     }
663
664 private:
665     NeighborhoodSearchTestData data_;
666 };
667
668 class TrivialSelfPairsTestData
669 {
670 public:
671     static const NeighborhoodSearchTestData& get()
672     {
673         static TrivialSelfPairsTestData singleton;
674         return singleton.data_;
675     }
676
677     TrivialSelfPairsTestData() : data_(12345, 1.0)
678     {
679         data_.box_[XX][XX] = 3.0;
680         data_.box_[YY][YY] = 3.0;
681         data_.box_[ZZ][ZZ] = 3.0;
682         data_.generateRandomRefPositions(20);
683         data_.useRefPositionsAsTestPositions();
684         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
685         data_.computeReferences(&data_.pbc_);
686     }
687
688 private:
689     NeighborhoodSearchTestData data_;
690 };
691
692 class TrivialNoPBCTestData
693 {
694 public:
695     static const NeighborhoodSearchTestData& get()
696     {
697         static TrivialNoPBCTestData singleton;
698         return singleton.data_;
699     }
700
701     TrivialNoPBCTestData() : data_(12345, 1.0)
702     {
703         data_.generateRandomRefPositions(10);
704         data_.generateRandomTestPositions(5);
705         data_.computeReferences(nullptr);
706     }
707
708 private:
709     NeighborhoodSearchTestData data_;
710 };
711
712 class RandomBoxFullPBCData
713 {
714 public:
715     static const NeighborhoodSearchTestData& get()
716     {
717         static RandomBoxFullPBCData singleton;
718         return singleton.data_;
719     }
720
721     RandomBoxFullPBCData() : data_(12345, 1.0)
722     {
723         data_.box_[XX][XX] = 10.0;
724         data_.box_[YY][YY] = 5.0;
725         data_.box_[ZZ][ZZ] = 7.0;
726         // TODO: Consider whether manually picking some positions would give better
727         // test coverage.
728         data_.generateRandomRefPositions(1000);
729         data_.generateRandomTestPositions(100);
730         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
731         data_.computeReferences(&data_.pbc_);
732     }
733
734 private:
735     NeighborhoodSearchTestData data_;
736 };
737
738 class RandomBoxSelfPairsData
739 {
740 public:
741     static const NeighborhoodSearchTestData& get()
742     {
743         static RandomBoxSelfPairsData singleton;
744         return singleton.data_;
745     }
746
747     RandomBoxSelfPairsData() : data_(12345, 1.0)
748     {
749         data_.box_[XX][XX] = 10.0;
750         data_.box_[YY][YY] = 5.0;
751         data_.box_[ZZ][ZZ] = 7.0;
752         data_.generateRandomRefPositions(1000);
753         data_.useRefPositionsAsTestPositions();
754         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
755         data_.computeReferences(&data_.pbc_);
756     }
757
758 private:
759     NeighborhoodSearchTestData data_;
760 };
761
762 class RandomBoxXYFullPBCData
763 {
764 public:
765     static const NeighborhoodSearchTestData& get()
766     {
767         static RandomBoxXYFullPBCData singleton;
768         return singleton.data_;
769     }
770
771     RandomBoxXYFullPBCData() : data_(54321, 1.0)
772     {
773         data_.box_[XX][XX] = 10.0;
774         data_.box_[YY][YY] = 5.0;
775         data_.box_[ZZ][ZZ] = 7.0;
776         // TODO: Consider whether manually picking some positions would give better
777         // test coverage.
778         data_.generateRandomRefPositions(1000);
779         data_.generateRandomTestPositions(100);
780         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
781         data_.computeReferencesXY(&data_.pbc_);
782     }
783
784 private:
785     NeighborhoodSearchTestData data_;
786 };
787
788 class RandomTriclinicFullPBCData
789 {
790 public:
791     static const NeighborhoodSearchTestData& get()
792     {
793         static RandomTriclinicFullPBCData singleton;
794         return singleton.data_;
795     }
796
797     RandomTriclinicFullPBCData() : data_(12345, 1.0)
798     {
799         data_.box_[XX][XX] = 5.0;
800         data_.box_[YY][XX] = 2.5;
801         data_.box_[YY][YY] = 2.5 * std::sqrt(3.0);
802         data_.box_[ZZ][XX] = 2.5;
803         data_.box_[ZZ][YY] = 2.5 * std::sqrt(1.0 / 3.0);
804         data_.box_[ZZ][ZZ] = 5.0 * std::sqrt(2.0 / 3.0);
805         // TODO: Consider whether manually picking some positions would give better
806         // test coverage.
807         data_.generateRandomRefPositions(1000);
808         data_.generateRandomTestPositions(100);
809         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
810         data_.computeReferences(&data_.pbc_);
811     }
812
813 private:
814     NeighborhoodSearchTestData data_;
815 };
816
817 class RandomBox2DPBCData
818 {
819 public:
820     static const NeighborhoodSearchTestData& get()
821     {
822         static RandomBox2DPBCData singleton;
823         return singleton.data_;
824     }
825
826     RandomBox2DPBCData() : data_(12345, 1.0)
827     {
828         data_.box_[XX][XX] = 10.0;
829         data_.box_[YY][YY] = 7.0;
830         data_.box_[ZZ][ZZ] = 5.0;
831         // TODO: Consider whether manually picking some positions would give better
832         // test coverage.
833         data_.generateRandomRefPositions(1000);
834         data_.generateRandomTestPositions(100);
835         set_pbc(&data_.pbc_, PbcType::XY, data_.box_);
836         data_.computeReferences(&data_.pbc_);
837     }
838
839 private:
840     NeighborhoodSearchTestData data_;
841 };
842
843 class RandomBoxNoPBCData
844 {
845 public:
846     static const NeighborhoodSearchTestData& get()
847     {
848         static RandomBoxNoPBCData singleton;
849         return singleton.data_;
850     }
851
852     RandomBoxNoPBCData() : data_(12345, 1.0)
853     {
854         data_.box_[XX][XX] = 10.0;
855         data_.box_[YY][YY] = 5.0;
856         data_.box_[ZZ][ZZ] = 7.0;
857         // TODO: Consider whether manually picking some positions would give better
858         // test coverage.
859         data_.generateRandomRefPositions(1000);
860         data_.generateRandomTestPositions(100);
861         set_pbc(&data_.pbc_, PbcType::No, data_.box_);
862         data_.computeReferences(nullptr);
863     }
864
865 private:
866     NeighborhoodSearchTestData data_;
867 };
868
869 /********************************************************************
870  * Actual tests
871  */
872
873 TEST_F(NeighborhoodSearchTest, SimpleSearch)
874 {
875     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
876
877     nb_.setCutoff(data.cutoff_);
878     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
879     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
880     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
881
882     testIsWithin(&search, data);
883     testMinimumDistance(&search, data);
884     testNearestPoint(&search, data);
885     testPairSearch(&search, data);
886
887     search.reset();
888     testPairSearchIndexed(&nb_, data, 123);
889 }
890
891 TEST_F(NeighborhoodSearchTest, SimpleSearchXY)
892 {
893     const NeighborhoodSearchTestData& data = RandomBoxXYFullPBCData::get();
894
895     nb_.setCutoff(data.cutoff_);
896     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
897     nb_.setXYMode(true);
898     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
899     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
900
901     testIsWithin(&search, data);
902     testMinimumDistance(&search, data);
903     testNearestPoint(&search, data);
904     testPairSearch(&search, data);
905 }
906
907 TEST_F(NeighborhoodSearchTest, GridSearchBox)
908 {
909     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
910
911     nb_.setCutoff(data.cutoff_);
912     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
913     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
914     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
915
916     testIsWithin(&search, data);
917     testMinimumDistance(&search, data);
918     testNearestPoint(&search, data);
919     testPairSearch(&search, data);
920
921     search.reset();
922     testPairSearchIndexed(&nb_, data, 456);
923 }
924
925 TEST_F(NeighborhoodSearchTest, GridSearchTriclinic)
926 {
927     const NeighborhoodSearchTestData& data = RandomTriclinicFullPBCData::get();
928
929     nb_.setCutoff(data.cutoff_);
930     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
931     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
932     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
933
934     testPairSearch(&search, data);
935 }
936
937 TEST_F(NeighborhoodSearchTest, GridSearch2DPBC)
938 {
939     const NeighborhoodSearchTestData& data = RandomBox2DPBCData::get();
940
941     nb_.setCutoff(data.cutoff_);
942     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
943     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
944     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
945
946     testIsWithin(&search, data);
947     testMinimumDistance(&search, data);
948     testNearestPoint(&search, data);
949     testPairSearch(&search, data);
950 }
951
952 TEST_F(NeighborhoodSearchTest, GridSearchNoPBC)
953 {
954     const NeighborhoodSearchTestData& data = RandomBoxNoPBCData::get();
955
956     nb_.setCutoff(data.cutoff_);
957     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
958     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
959     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
960
961     testPairSearch(&search, data);
962 }
963
964 TEST_F(NeighborhoodSearchTest, GridSearchXYBox)
965 {
966     const NeighborhoodSearchTestData& data = RandomBoxXYFullPBCData::get();
967
968     nb_.setCutoff(data.cutoff_);
969     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
970     nb_.setXYMode(true);
971     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
972     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
973
974     testIsWithin(&search, data);
975     testMinimumDistance(&search, data);
976     testNearestPoint(&search, data);
977     testPairSearch(&search, data);
978 }
979
980 TEST_F(NeighborhoodSearchTest, SimpleSelfPairsSearch)
981 {
982     const NeighborhoodSearchTestData& data = TrivialSelfPairsTestData::get();
983
984     nb_.setCutoff(data.cutoff_);
985     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
986     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
987     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
988
989     testPairSearchFull(&search, data, data.testPositions(), nullptr, {}, {}, true);
990 }
991
992 TEST_F(NeighborhoodSearchTest, GridSelfPairsSearch)
993 {
994     const NeighborhoodSearchTestData& data = RandomBoxSelfPairsData::get();
995
996     nb_.setCutoff(data.cutoff_);
997     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
998     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
999     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
1000
1001     testPairSearchFull(&search, data, data.testPositions(), nullptr, {}, {}, true);
1002 }
1003
1004 TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
1005 {
1006     const NeighborhoodSearchTestData& data = TrivialTestData::get();
1007
1008     nb_.setCutoff(data.cutoff_);
1009     gmx::AnalysisNeighborhoodSearch search1 = nb_.initSearch(&data.pbc_, data.refPositions());
1010     gmx::AnalysisNeighborhoodSearch search2 = nb_.initSearch(&data.pbc_, data.refPositions());
1011
1012     // These checks are fragile, and unfortunately depend on the random
1013     // engine used to create the test positions. There is no particular reason
1014     // why exactly particles 0 & 2 should have neighbors, but in this case they do.
1015     gmx::AnalysisNeighborhoodPairSearch pairSearch1 = search1.startPairSearch(data.testPosition(0));
1016     gmx::AnalysisNeighborhoodPairSearch pairSearch2 = search1.startPairSearch(data.testPosition(2));
1017
1018     testPairSearch(&search2, data);
1019
1020     gmx::AnalysisNeighborhoodPair pair;
1021     ASSERT_TRUE(pairSearch1.findNextPair(&pair))
1022             << "Test data did not contain any pairs for position 0 (problem in the test).";
1023     EXPECT_EQ(0, pair.testIndex());
1024     {
1025         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1026         EXPECT_TRUE(data.containsPair(0, searchPair));
1027     }
1028
1029     ASSERT_TRUE(pairSearch2.findNextPair(&pair))
1030             << "Test data did not contain any pairs for position 2 (problem in the test).";
1031     EXPECT_EQ(2, pair.testIndex());
1032     {
1033         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1034         EXPECT_TRUE(data.containsPair(2, searchPair));
1035     }
1036 }
1037
1038 TEST_F(NeighborhoodSearchTest, HandlesNoPBC)
1039 {
1040     const NeighborhoodSearchTestData& data = TrivialNoPBCTestData::get();
1041
1042     nb_.setCutoff(data.cutoff_);
1043     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
1044     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1045
1046     testIsWithin(&search, data);
1047     testMinimumDistance(&search, data);
1048     testNearestPoint(&search, data);
1049     testPairSearch(&search, data);
1050 }
1051
1052 TEST_F(NeighborhoodSearchTest, HandlesNullPBC)
1053 {
1054     const NeighborhoodSearchTestData& data = TrivialNoPBCTestData::get();
1055
1056     nb_.setCutoff(data.cutoff_);
1057     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(nullptr, data.refPositions());
1058     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1059
1060     testIsWithin(&search, data);
1061     testMinimumDistance(&search, data);
1062     testNearestPoint(&search, data);
1063     testPairSearch(&search, data);
1064 }
1065
1066 TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
1067 {
1068     const NeighborhoodSearchTestData& data = TrivialTestData::get();
1069
1070     nb_.setCutoff(data.cutoff_);
1071     gmx::AnalysisNeighborhoodSearch     search = nb_.initSearch(&data.pbc_, data.refPositions());
1072     gmx::AnalysisNeighborhoodPairSearch pairSearch = search.startPairSearch(data.testPositions());
1073     gmx::AnalysisNeighborhoodPair       pair;
1074     // TODO: This test needs to be adjusted if the grid search gets optimized
1075     // to loop over the test positions in cell order (first, the ordering
1076     // assumption here breaks, and second, it then needs to be tested
1077     // separately for simple and grid searches).
1078     int currentIndex = 0;
1079     while (pairSearch.findNextPair(&pair))
1080     {
1081         while (currentIndex < pair.testIndex())
1082         {
1083             ++currentIndex;
1084         }
1085         EXPECT_EQ(currentIndex, pair.testIndex());
1086         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1087         EXPECT_TRUE(data.containsPair(currentIndex, searchPair));
1088         pairSearch.skipRemainingPairsForTestPosition();
1089         ++currentIndex;
1090     }
1091 }
1092
1093 TEST_F(NeighborhoodSearchTest, SimpleSearchExclusions)
1094 {
1095     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
1096
1097     ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
1098     helper.generateExclusions();
1099
1100     nb_.setCutoff(data.cutoff_);
1101     nb_.setTopologyExclusions(helper.exclusions());
1102     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
1103     gmx::AnalysisNeighborhoodSearch search =
1104             nb_.initSearch(&data.pbc_, data.refPositions().exclusionIds(helper.refPosIds()));
1105     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1106
1107     testPairSearchFull(&search,
1108                        data,
1109                        data.testPositions().exclusionIds(helper.testPosIds()),
1110                        helper.exclusions(),
1111                        {},
1112                        {},
1113                        false);
1114 }
1115
1116 TEST_F(NeighborhoodSearchTest, GridSearchExclusions)
1117 {
1118     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
1119
1120     ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
1121     helper.generateExclusions();
1122
1123     nb_.setCutoff(data.cutoff_);
1124     nb_.setTopologyExclusions(helper.exclusions());
1125     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
1126     gmx::AnalysisNeighborhoodSearch search =
1127             nb_.initSearch(&data.pbc_, data.refPositions().exclusionIds(helper.refPosIds()));
1128     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
1129
1130     testPairSearchFull(&search,
1131                        data,
1132                        data.testPositions().exclusionIds(helper.testPosIds()),
1133                        helper.exclusions(),
1134                        {},
1135                        {},
1136                        false);
1137 }
1138
1139 } // namespace