Make PBC type enumeration into PbcType enum class
[alexxy/gromacs.git] / src / gromacs / selection / tests / nbsearch.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
5  * Copyright (c) 2018,2019,2020, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36 /*! \internal \file
37  * \brief
38  * Tests selection neighborhood searching.
39  *
40  * \todo
41  * Increase coverage of these tests for different corner cases: other PBC cases
42  * than full 3D, large cutoffs (larger than half the box size), etc.
43  * At least some of these probably don't work correctly.
44  *
45  * \author Teemu Murtola <teemu.murtola@gmail.com>
46  * \ingroup module_selection
47  */
48 #include "gmxpre.h"
49
50 #include "gromacs/selection/nbsearch.h"
51
52 #include <cmath>
53
54 #include <algorithm>
55 #include <limits>
56 #include <map>
57 #include <numeric>
58 #include <vector>
59
60 #include <gtest/gtest.h>
61
62 #include "gromacs/math/functions.h"
63 #include "gromacs/math/vec.h"
64 #include "gromacs/pbcutil/pbc.h"
65 #include "gromacs/random/threefry.h"
66 #include "gromacs/random/uniformrealdistribution.h"
67 #include "gromacs/topology/block.h"
68 #include "gromacs/utility/listoflists.h"
69 #include "gromacs/utility/smalloc.h"
70 #include "gromacs/utility/stringutil.h"
71
72 #include "testutils/testasserts.h"
73
74
75 namespace
76 {
77
78 /********************************************************************
79  * NeighborhoodSearchTestData
80  */
81
82 class NeighborhoodSearchTestData
83 {
84 public:
85     struct RefPair
86     {
87         RefPair(int refIndex, real distance) :
88             refIndex(refIndex),
89             distance(distance),
90             bFound(false),
91             bExcluded(false),
92             bIndexed(true)
93         {
94         }
95
96         bool operator<(const RefPair& other) const { return refIndex < other.refIndex; }
97
98         int  refIndex;
99         real distance;
100         // The variables below are state variables that are only used
101         // during the actual testing after creating a copy of the reference
102         // pair list, not as part of the reference data.
103         // Simpler to have just a single structure for both purposes.
104         bool bFound;
105         bool bExcluded;
106         bool bIndexed;
107     };
108
109     struct TestPosition
110     {
111         explicit TestPosition(const rvec x) : refMinDist(0.0), refNearestPoint(-1)
112         {
113             copy_rvec(x, this->x);
114         }
115
116         rvec                 x;
117         real                 refMinDist;
118         int                  refNearestPoint;
119         std::vector<RefPair> refPairs;
120     };
121
122     typedef std::vector<TestPosition> TestPositionList;
123
124     NeighborhoodSearchTestData(uint64_t seed, real cutoff);
125
126     gmx::AnalysisNeighborhoodPositions refPositions() const
127     {
128         return gmx::AnalysisNeighborhoodPositions(refPos_);
129     }
130     gmx::AnalysisNeighborhoodPositions testPositions() const
131     {
132         if (testPos_.empty())
133         {
134             testPos_.reserve(testPositions_.size());
135             for (size_t i = 0; i < testPositions_.size(); ++i)
136             {
137                 testPos_.emplace_back(testPositions_[i].x);
138             }
139         }
140         return gmx::AnalysisNeighborhoodPositions(testPos_);
141     }
142     gmx::AnalysisNeighborhoodPositions testPosition(int index) const
143     {
144         return testPositions().selectSingleFromArray(index);
145     }
146
147     void addTestPosition(const rvec x)
148     {
149         GMX_RELEASE_ASSERT(testPos_.empty(), "Cannot add positions after testPositions() call");
150         testPositions_.emplace_back(x);
151     }
152     gmx::RVec        generateRandomPosition();
153     std::vector<int> generateIndex(int count, uint64_t seed) const;
154     void             generateRandomRefPositions(int count);
155     void             generateRandomTestPositions(int count);
156     void             useRefPositionsAsTestPositions();
157     void             computeReferences(t_pbc* pbc) { computeReferencesInternal(pbc, false); }
158     void             computeReferencesXY(t_pbc* pbc) { computeReferencesInternal(pbc, true); }
159
160     bool containsPair(int testIndex, const RefPair& pair) const
161     {
162         const std::vector<RefPair>&          refPairs = testPositions_[testIndex].refPairs;
163         std::vector<RefPair>::const_iterator foundRefPair =
164                 std::lower_bound(refPairs.begin(), refPairs.end(), pair);
165         return !(foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex);
166     }
167
168     // Return a tolerance that accounts for the magnitudes of the coordinates
169     // when doing subtractions, so that we set the ULP tolerance relative to the
170     // coordinate values rather than their difference.
171     // i.e., 10.0-9.9999999 will achieve a few ULP accuracy relative
172     // to 10.0, but a much larger error relative to the difference.
173     gmx::test::FloatingPointTolerance relativeTolerance() const
174     {
175         real magnitude = std::max(box_[XX][XX], std::max(box_[YY][YY], box_[ZZ][ZZ]));
176         return gmx::test::relativeToleranceAsUlp(magnitude, 4);
177     }
178
179     gmx::DefaultRandomEngine rng_;
180     real                     cutoff_;
181     matrix                   box_;
182     t_pbc                    pbc_;
183     int                      refPosCount_;
184     std::vector<gmx::RVec>   refPos_;
185     TestPositionList         testPositions_;
186
187 private:
188     void computeReferencesInternal(t_pbc* pbc, bool bXY);
189
190     mutable std::vector<gmx::RVec> testPos_;
191 };
192
193 //! Shorthand for a collection of reference pairs.
194 typedef std::vector<NeighborhoodSearchTestData::RefPair> RefPairList;
195
196 NeighborhoodSearchTestData::NeighborhoodSearchTestData(uint64_t seed, real cutoff) :
197     rng_(seed),
198     cutoff_(cutoff),
199     refPosCount_(0)
200 {
201     clear_mat(box_);
202     set_pbc(&pbc_, PbcType::No, box_);
203 }
204
205 gmx::RVec NeighborhoodSearchTestData::generateRandomPosition()
206 {
207     gmx::UniformRealDistribution<real> dist;
208     rvec                               fx, x;
209     fx[XX] = dist(rng_);
210     fx[YY] = dist(rng_);
211     fx[ZZ] = dist(rng_);
212     mvmul(box_, fx, x);
213     // Add a small displacement to allow positions outside the box
214     x[XX] += 0.2 * dist(rng_) - 0.1;
215     x[YY] += 0.2 * dist(rng_) - 0.1;
216     x[ZZ] += 0.2 * dist(rng_) - 0.1;
217     return x;
218 }
219
220 std::vector<int> NeighborhoodSearchTestData::generateIndex(int count, uint64_t seed) const
221 {
222     gmx::DefaultRandomEngine           rngIndex(seed);
223     gmx::UniformRealDistribution<real> dist;
224     std::vector<int>                   result;
225
226     for (int i = 0; i < count; ++i)
227     {
228         if (dist(rngIndex) > 0.5)
229         {
230             result.push_back(i);
231         }
232     }
233     return result;
234 }
235
236 void NeighborhoodSearchTestData::generateRandomRefPositions(int count)
237 {
238     refPosCount_ = count;
239     refPos_.reserve(count);
240     for (int i = 0; i < count; ++i)
241     {
242         refPos_.push_back(generateRandomPosition());
243     }
244 }
245
246 void NeighborhoodSearchTestData::generateRandomTestPositions(int count)
247 {
248     testPositions_.reserve(count);
249     for (int i = 0; i < count; ++i)
250     {
251         addTestPosition(generateRandomPosition());
252     }
253 }
254
255 void NeighborhoodSearchTestData::useRefPositionsAsTestPositions()
256 {
257     testPositions_.reserve(refPosCount_);
258     for (const auto& refPos : refPos_)
259     {
260         addTestPosition(refPos);
261     }
262 }
263
264 void NeighborhoodSearchTestData::computeReferencesInternal(t_pbc* pbc, bool bXY)
265 {
266     real cutoff = cutoff_;
267     if (cutoff <= 0)
268     {
269         cutoff = std::numeric_limits<real>::max();
270     }
271     for (TestPosition& testPos : testPositions_)
272     {
273         testPos.refMinDist      = cutoff;
274         testPos.refNearestPoint = -1;
275         testPos.refPairs.clear();
276         for (int j = 0; j < refPosCount_; ++j)
277         {
278             rvec dx;
279             if (pbc != nullptr)
280             {
281                 pbc_dx(pbc, testPos.x, refPos_[j], dx);
282             }
283             else
284             {
285                 rvec_sub(testPos.x, refPos_[j], dx);
286             }
287             // TODO: This may not work intuitively for 2D with the third box
288             // vector not parallel to the Z axis, but neither does the actual
289             // neighborhood search.
290             const real dist = !bXY ? norm(dx) : std::hypot(dx[XX], dx[YY]);
291             if (dist < testPos.refMinDist)
292             {
293                 testPos.refMinDist      = dist;
294                 testPos.refNearestPoint = j;
295             }
296             if (dist > 0 && dist <= cutoff)
297             {
298                 RefPair pair(j, dist);
299                 GMX_RELEASE_ASSERT(testPos.refPairs.empty() || testPos.refPairs.back() < pair,
300                                    "Reference pairs should be generated in sorted order");
301                 testPos.refPairs.push_back(pair);
302             }
303         }
304     }
305 }
306
307 /********************************************************************
308  * ExclusionsHelper
309  */
310
311 class ExclusionsHelper
312 {
313 public:
314     static void markExcludedPairs(RefPairList* refPairs, int testIndex, const gmx::ListOfLists<int>* excls);
315
316     ExclusionsHelper(int refPosCount, int testPosCount);
317
318     void generateExclusions();
319
320     const gmx::ListOfLists<int>* exclusions() const { return &excls_; }
321
322     gmx::ArrayRef<const int> refPosIds() const
323     {
324         return gmx::makeArrayRef(exclusionIds_).subArray(0, refPosCount_);
325     }
326     gmx::ArrayRef<const int> testPosIds() const
327     {
328         return gmx::makeArrayRef(exclusionIds_).subArray(0, testPosCount_);
329     }
330
331 private:
332     int                   refPosCount_;
333     int                   testPosCount_;
334     std::vector<int>      exclusionIds_;
335     gmx::ListOfLists<int> excls_;
336 };
337
338 // static
339 void ExclusionsHelper::markExcludedPairs(RefPairList* refPairs, int testIndex, const gmx::ListOfLists<int>* excls)
340 {
341     int count = 0;
342     for (const int excludedIndex : (*excls)[testIndex])
343     {
344         NeighborhoodSearchTestData::RefPair searchPair(excludedIndex, 0.0);
345         RefPairList::iterator               excludedRefPair =
346                 std::lower_bound(refPairs->begin(), refPairs->end(), searchPair);
347         if (excludedRefPair != refPairs->end() && excludedRefPair->refIndex == excludedIndex)
348         {
349             excludedRefPair->bFound    = true;
350             excludedRefPair->bExcluded = true;
351             ++count;
352         }
353     }
354 }
355
356 ExclusionsHelper::ExclusionsHelper(int refPosCount, int testPosCount) :
357     refPosCount_(refPosCount),
358     testPosCount_(testPosCount)
359 {
360     // Generate an array of 0, 1, 2, ...
361     // TODO: Make the tests work also with non-trivial exclusion IDs,
362     // and test that.
363     exclusionIds_.resize(std::max(refPosCount, testPosCount), 1);
364     exclusionIds_[0] = 0;
365     std::partial_sum(exclusionIds_.begin(), exclusionIds_.end(), exclusionIds_.begin());
366 }
367
368 void ExclusionsHelper::generateExclusions()
369 {
370     // TODO: Consider a better set of test data, where the density of the
371     // particles would be higher, or where the exclusions would not be random,
372     // to make a higher percentage of the exclusions to actually be within the
373     // cutoff.
374     for (int i = 0; i < testPosCount_; ++i)
375     {
376         excls_.pushBackListOfSize(20);
377         gmx::ArrayRef<int> exclusionsForAtom = excls_.back();
378         for (int j = 0; j < 20; ++j)
379         {
380             exclusionsForAtom[j] = i + j * 3;
381         }
382     }
383 }
384
385 /********************************************************************
386  * NeighborhoodSearchTest
387  */
388
389 class NeighborhoodSearchTest : public ::testing::Test
390 {
391 public:
392     void testIsWithin(gmx::AnalysisNeighborhoodSearch* search, const NeighborhoodSearchTestData& data);
393     void testMinimumDistance(gmx::AnalysisNeighborhoodSearch*  search,
394                              const NeighborhoodSearchTestData& data);
395     void testNearestPoint(gmx::AnalysisNeighborhoodSearch* search, const NeighborhoodSearchTestData& data);
396     void testPairSearch(gmx::AnalysisNeighborhoodSearch* search, const NeighborhoodSearchTestData& data);
397     void testPairSearchIndexed(gmx::AnalysisNeighborhood*        nb,
398                                const NeighborhoodSearchTestData& data,
399                                uint64_t                          seed);
400     void testPairSearchFull(gmx::AnalysisNeighborhoodSearch*          search,
401                             const NeighborhoodSearchTestData&         data,
402                             const gmx::AnalysisNeighborhoodPositions& pos,
403                             const gmx::ListOfLists<int>*              excls,
404                             const gmx::ArrayRef<const int>&           refIndices,
405                             const gmx::ArrayRef<const int>&           testIndices,
406                             bool                                      selfPairs);
407
408     gmx::AnalysisNeighborhood nb_;
409 };
410
411 void NeighborhoodSearchTest::testIsWithin(gmx::AnalysisNeighborhoodSearch*  search,
412                                           const NeighborhoodSearchTestData& data)
413 {
414     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
415     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
416     {
417         const bool bWithin = (i->refMinDist <= data.cutoff_);
418         EXPECT_EQ(bWithin, search->isWithin(i->x)) << "Distance is " << i->refMinDist;
419     }
420 }
421
422 void NeighborhoodSearchTest::testMinimumDistance(gmx::AnalysisNeighborhoodSearch*  search,
423                                                  const NeighborhoodSearchTestData& data)
424 {
425     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
426
427     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
428     {
429         const real refDist = i->refMinDist;
430         EXPECT_REAL_EQ_TOL(refDist, search->minimumDistance(i->x), data.relativeTolerance());
431     }
432 }
433
434 void NeighborhoodSearchTest::testNearestPoint(gmx::AnalysisNeighborhoodSearch*  search,
435                                               const NeighborhoodSearchTestData& data)
436 {
437     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
438     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
439     {
440         const gmx::AnalysisNeighborhoodPair pair = search->nearestPoint(i->x);
441         if (pair.isValid())
442         {
443             EXPECT_EQ(i->refNearestPoint, pair.refIndex());
444             EXPECT_EQ(0, pair.testIndex());
445             EXPECT_REAL_EQ_TOL(i->refMinDist, std::sqrt(pair.distance2()), data.relativeTolerance());
446         }
447         else
448         {
449             EXPECT_EQ(i->refNearestPoint, -1);
450         }
451     }
452 }
453
454 //! Helper function for formatting test failure messages.
455 std::string formatVector(const rvec x)
456 {
457     return gmx::formatString("[%.3f, %.3f, %.3f]", x[XX], x[YY], x[ZZ]);
458 }
459
460 /*! \brief
461  * Helper function to check that all expected pairs were found.
462  */
463 void checkAllPairsFound(const RefPairList&            refPairs,
464                         const std::vector<gmx::RVec>& refPos,
465                         int                           testPosIndex,
466                         const rvec                    testPos)
467 {
468     // This could be elegantly expressed with Google Mock matchers, but that
469     // has a significant effect on the runtime of the tests...
470     int                         count = 0;
471     RefPairList::const_iterator first;
472     for (RefPairList::const_iterator i = refPairs.begin(); i != refPairs.end(); ++i)
473     {
474         if (!i->bFound)
475         {
476             ++count;
477             first = i;
478         }
479     }
480     if (count > 0)
481     {
482         ADD_FAILURE() << "Some pairs (" << count << "/" << refPairs.size() << ") "
483                       << "within the cutoff were not found. First pair:\n"
484                       << " Ref: " << first->refIndex << " at "
485                       << formatVector(refPos[first->refIndex]) << "\n"
486                       << "Test: " << testPosIndex << " at " << formatVector(testPos) << "\n"
487                       << "Dist: " << first->distance;
488     }
489 }
490
491 void NeighborhoodSearchTest::testPairSearch(gmx::AnalysisNeighborhoodSearch*  search,
492                                             const NeighborhoodSearchTestData& data)
493 {
494     testPairSearchFull(search, data, data.testPositions(), nullptr, {}, {}, false);
495 }
496
497 void NeighborhoodSearchTest::testPairSearchIndexed(gmx::AnalysisNeighborhood*        nb,
498                                                    const NeighborhoodSearchTestData& data,
499                                                    uint64_t                          seed)
500 {
501     std::vector<int> refIndices(data.generateIndex(data.refPos_.size(), seed++));
502     std::vector<int> testIndices(data.generateIndex(data.testPositions_.size(), seed++));
503     gmx::AnalysisNeighborhoodSearch search =
504             nb->initSearch(&data.pbc_, data.refPositions().indexed(refIndices));
505     testPairSearchFull(&search, data, data.testPositions(), nullptr, refIndices, testIndices, false);
506 }
507
508 void NeighborhoodSearchTest::testPairSearchFull(gmx::AnalysisNeighborhoodSearch*          search,
509                                                 const NeighborhoodSearchTestData&         data,
510                                                 const gmx::AnalysisNeighborhoodPositions& pos,
511                                                 const gmx::ListOfLists<int>*              excls,
512                                                 const gmx::ArrayRef<const int>& refIndices,
513                                                 const gmx::ArrayRef<const int>& testIndices,
514                                                 bool                            selfPairs)
515 {
516     std::map<int, RefPairList> refPairs;
517     // TODO: Some parts of this code do not work properly if pos does not
518     // initially contain all the test positions.
519     if (testIndices.empty())
520     {
521         for (size_t i = 0; i < data.testPositions_.size(); ++i)
522         {
523             refPairs[i] = data.testPositions_[i].refPairs;
524         }
525     }
526     else
527     {
528         for (int index : testIndices)
529         {
530             refPairs[index] = data.testPositions_[index].refPairs;
531         }
532     }
533     if (excls != nullptr)
534     {
535         GMX_RELEASE_ASSERT(!selfPairs, "Self-pairs testing not implemented with exclusions");
536         for (auto& entry : refPairs)
537         {
538             const int testIndex = entry.first;
539             ExclusionsHelper::markExcludedPairs(&entry.second, testIndex, excls);
540         }
541     }
542     if (!refIndices.empty())
543     {
544         GMX_RELEASE_ASSERT(!selfPairs, "Self-pairs testing not implemented with indexing");
545         for (auto& entry : refPairs)
546         {
547             for (auto& refPair : entry.second)
548             {
549                 refPair.bIndexed = false;
550             }
551             for (int index : refIndices)
552             {
553                 NeighborhoodSearchTestData::RefPair searchPair(index, 0.0);
554                 auto refPair = std::lower_bound(entry.second.begin(), entry.second.end(), searchPair);
555                 if (refPair != entry.second.end() && refPair->refIndex == index)
556                 {
557                     refPair->bIndexed = true;
558                 }
559             }
560             for (auto& refPair : entry.second)
561             {
562                 if (!refPair.bIndexed)
563                 {
564                     refPair.bFound = true;
565                 }
566             }
567         }
568     }
569
570     gmx::AnalysisNeighborhoodPositions posCopy(pos);
571     if (!testIndices.empty())
572     {
573         posCopy.indexed(testIndices);
574     }
575     gmx::AnalysisNeighborhoodPairSearch pairSearch =
576             selfPairs ? search->startSelfPairSearch() : search->startPairSearch(posCopy);
577     gmx::AnalysisNeighborhoodPair pair;
578     while (pairSearch.findNextPair(&pair))
579     {
580         const int testIndex = (testIndices.empty() ? pair.testIndex() : testIndices[pair.testIndex()]);
581         const int refIndex = (refIndices.empty() ? pair.refIndex() : refIndices[pair.refIndex()]);
582
583         if (refPairs.count(testIndex) == 0)
584         {
585             ADD_FAILURE() << "Expected: No pairs are returned for test position " << testIndex << ".\n"
586                           << "  Actual: Pair with ref " << refIndex << " is returned.";
587             continue;
588         }
589         NeighborhoodSearchTestData::RefPair searchPair(refIndex, std::sqrt(pair.distance2()));
590         const auto                          foundRefPair =
591                 std::lower_bound(refPairs[testIndex].begin(), refPairs[testIndex].end(), searchPair);
592         if (foundRefPair == refPairs[testIndex].end() || foundRefPair->refIndex != refIndex)
593         {
594             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
595                           << ") is not within the cutoff.\n"
596                           << "  Actual: It is returned.";
597         }
598         else if (foundRefPair->bExcluded)
599         {
600             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
601                           << ") is excluded from the search.\n"
602                           << "  Actual: It is returned.";
603         }
604         else if (!foundRefPair->bIndexed)
605         {
606             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
607                           << ") is not part of the indexed set.\n"
608                           << "  Actual: It is returned.";
609         }
610         else if (foundRefPair->bFound)
611         {
612             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
613                           << ") is returned only once.\n"
614                           << "  Actual: It is returned multiple times.";
615             return;
616         }
617         else
618         {
619             foundRefPair->bFound = true;
620
621             EXPECT_REAL_EQ_TOL(foundRefPair->distance, searchPair.distance, data.relativeTolerance())
622                     << "Distance computed by the neighborhood search does not match.";
623             if (selfPairs)
624             {
625                 searchPair              = NeighborhoodSearchTestData::RefPair(testIndex, 0.0);
626                 const auto otherRefPair = std::lower_bound(refPairs[refIndex].begin(),
627                                                            refPairs[refIndex].end(), searchPair);
628                 GMX_RELEASE_ASSERT(otherRefPair != refPairs[refIndex].end(),
629                                    "Precomputed reference data is not symmetric");
630                 otherRefPair->bFound = true;
631             }
632         }
633     }
634
635     for (auto& entry : refPairs)
636     {
637         const int testIndex = entry.first;
638         checkAllPairsFound(entry.second, data.refPos_, testIndex, data.testPositions_[testIndex].x);
639     }
640 }
641
642 /********************************************************************
643  * Test data generation
644  */
645
646 class TrivialTestData
647 {
648 public:
649     static const NeighborhoodSearchTestData& get()
650     {
651         static TrivialTestData singleton;
652         return singleton.data_;
653     }
654
655     TrivialTestData() : data_(12345, 1.0)
656     {
657         // Make the box so small we are virtually guaranteed to have
658         // several neighbors for the five test positions
659         data_.box_[XX][XX] = 3.0;
660         data_.box_[YY][YY] = 3.0;
661         data_.box_[ZZ][ZZ] = 3.0;
662         data_.generateRandomRefPositions(10);
663         data_.generateRandomTestPositions(5);
664         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
665         data_.computeReferences(&data_.pbc_);
666     }
667
668 private:
669     NeighborhoodSearchTestData data_;
670 };
671
672 class TrivialSelfPairsTestData
673 {
674 public:
675     static const NeighborhoodSearchTestData& get()
676     {
677         static TrivialSelfPairsTestData singleton;
678         return singleton.data_;
679     }
680
681     TrivialSelfPairsTestData() : data_(12345, 1.0)
682     {
683         data_.box_[XX][XX] = 3.0;
684         data_.box_[YY][YY] = 3.0;
685         data_.box_[ZZ][ZZ] = 3.0;
686         data_.generateRandomRefPositions(20);
687         data_.useRefPositionsAsTestPositions();
688         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
689         data_.computeReferences(&data_.pbc_);
690     }
691
692 private:
693     NeighborhoodSearchTestData data_;
694 };
695
696 class TrivialNoPBCTestData
697 {
698 public:
699     static const NeighborhoodSearchTestData& get()
700     {
701         static TrivialNoPBCTestData singleton;
702         return singleton.data_;
703     }
704
705     TrivialNoPBCTestData() : data_(12345, 1.0)
706     {
707         data_.generateRandomRefPositions(10);
708         data_.generateRandomTestPositions(5);
709         data_.computeReferences(nullptr);
710     }
711
712 private:
713     NeighborhoodSearchTestData data_;
714 };
715
716 class RandomBoxFullPBCData
717 {
718 public:
719     static const NeighborhoodSearchTestData& get()
720     {
721         static RandomBoxFullPBCData singleton;
722         return singleton.data_;
723     }
724
725     RandomBoxFullPBCData() : data_(12345, 1.0)
726     {
727         data_.box_[XX][XX] = 10.0;
728         data_.box_[YY][YY] = 5.0;
729         data_.box_[ZZ][ZZ] = 7.0;
730         // TODO: Consider whether manually picking some positions would give better
731         // test coverage.
732         data_.generateRandomRefPositions(1000);
733         data_.generateRandomTestPositions(100);
734         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
735         data_.computeReferences(&data_.pbc_);
736     }
737
738 private:
739     NeighborhoodSearchTestData data_;
740 };
741
742 class RandomBoxSelfPairsData
743 {
744 public:
745     static const NeighborhoodSearchTestData& get()
746     {
747         static RandomBoxSelfPairsData singleton;
748         return singleton.data_;
749     }
750
751     RandomBoxSelfPairsData() : data_(12345, 1.0)
752     {
753         data_.box_[XX][XX] = 10.0;
754         data_.box_[YY][YY] = 5.0;
755         data_.box_[ZZ][ZZ] = 7.0;
756         data_.generateRandomRefPositions(1000);
757         data_.useRefPositionsAsTestPositions();
758         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
759         data_.computeReferences(&data_.pbc_);
760     }
761
762 private:
763     NeighborhoodSearchTestData data_;
764 };
765
766 class RandomBoxXYFullPBCData
767 {
768 public:
769     static const NeighborhoodSearchTestData& get()
770     {
771         static RandomBoxXYFullPBCData singleton;
772         return singleton.data_;
773     }
774
775     RandomBoxXYFullPBCData() : data_(54321, 1.0)
776     {
777         data_.box_[XX][XX] = 10.0;
778         data_.box_[YY][YY] = 5.0;
779         data_.box_[ZZ][ZZ] = 7.0;
780         // TODO: Consider whether manually picking some positions would give better
781         // test coverage.
782         data_.generateRandomRefPositions(1000);
783         data_.generateRandomTestPositions(100);
784         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
785         data_.computeReferencesXY(&data_.pbc_);
786     }
787
788 private:
789     NeighborhoodSearchTestData data_;
790 };
791
792 class RandomTriclinicFullPBCData
793 {
794 public:
795     static const NeighborhoodSearchTestData& get()
796     {
797         static RandomTriclinicFullPBCData singleton;
798         return singleton.data_;
799     }
800
801     RandomTriclinicFullPBCData() : data_(12345, 1.0)
802     {
803         data_.box_[XX][XX] = 5.0;
804         data_.box_[YY][XX] = 2.5;
805         data_.box_[YY][YY] = 2.5 * std::sqrt(3.0);
806         data_.box_[ZZ][XX] = 2.5;
807         data_.box_[ZZ][YY] = 2.5 * std::sqrt(1.0 / 3.0);
808         data_.box_[ZZ][ZZ] = 5.0 * std::sqrt(2.0 / 3.0);
809         // TODO: Consider whether manually picking some positions would give better
810         // test coverage.
811         data_.generateRandomRefPositions(1000);
812         data_.generateRandomTestPositions(100);
813         set_pbc(&data_.pbc_, PbcType::Xyz, data_.box_);
814         data_.computeReferences(&data_.pbc_);
815     }
816
817 private:
818     NeighborhoodSearchTestData data_;
819 };
820
821 class RandomBox2DPBCData
822 {
823 public:
824     static const NeighborhoodSearchTestData& get()
825     {
826         static RandomBox2DPBCData singleton;
827         return singleton.data_;
828     }
829
830     RandomBox2DPBCData() : data_(12345, 1.0)
831     {
832         data_.box_[XX][XX] = 10.0;
833         data_.box_[YY][YY] = 7.0;
834         data_.box_[ZZ][ZZ] = 5.0;
835         // TODO: Consider whether manually picking some positions would give better
836         // test coverage.
837         data_.generateRandomRefPositions(1000);
838         data_.generateRandomTestPositions(100);
839         set_pbc(&data_.pbc_, PbcType::XY, data_.box_);
840         data_.computeReferences(&data_.pbc_);
841     }
842
843 private:
844     NeighborhoodSearchTestData data_;
845 };
846
847 class RandomBoxNoPBCData
848 {
849 public:
850     static const NeighborhoodSearchTestData& get()
851     {
852         static RandomBoxNoPBCData singleton;
853         return singleton.data_;
854     }
855
856     RandomBoxNoPBCData() : data_(12345, 1.0)
857     {
858         data_.box_[XX][XX] = 10.0;
859         data_.box_[YY][YY] = 5.0;
860         data_.box_[ZZ][ZZ] = 7.0;
861         // TODO: Consider whether manually picking some positions would give better
862         // test coverage.
863         data_.generateRandomRefPositions(1000);
864         data_.generateRandomTestPositions(100);
865         set_pbc(&data_.pbc_, PbcType::No, data_.box_);
866         data_.computeReferences(nullptr);
867     }
868
869 private:
870     NeighborhoodSearchTestData data_;
871 };
872
873 /********************************************************************
874  * Actual tests
875  */
876
877 TEST_F(NeighborhoodSearchTest, SimpleSearch)
878 {
879     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
880
881     nb_.setCutoff(data.cutoff_);
882     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
883     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
884     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
885
886     testIsWithin(&search, data);
887     testMinimumDistance(&search, data);
888     testNearestPoint(&search, data);
889     testPairSearch(&search, data);
890
891     search.reset();
892     testPairSearchIndexed(&nb_, data, 123);
893 }
894
895 TEST_F(NeighborhoodSearchTest, SimpleSearchXY)
896 {
897     const NeighborhoodSearchTestData& data = RandomBoxXYFullPBCData::get();
898
899     nb_.setCutoff(data.cutoff_);
900     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
901     nb_.setXYMode(true);
902     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
903     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
904
905     testIsWithin(&search, data);
906     testMinimumDistance(&search, data);
907     testNearestPoint(&search, data);
908     testPairSearch(&search, data);
909 }
910
911 TEST_F(NeighborhoodSearchTest, GridSearchBox)
912 {
913     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
914
915     nb_.setCutoff(data.cutoff_);
916     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
917     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
918     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
919
920     testIsWithin(&search, data);
921     testMinimumDistance(&search, data);
922     testNearestPoint(&search, data);
923     testPairSearch(&search, data);
924
925     search.reset();
926     testPairSearchIndexed(&nb_, data, 456);
927 }
928
929 TEST_F(NeighborhoodSearchTest, GridSearchTriclinic)
930 {
931     const NeighborhoodSearchTestData& data = RandomTriclinicFullPBCData::get();
932
933     nb_.setCutoff(data.cutoff_);
934     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
935     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
936     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
937
938     testPairSearch(&search, data);
939 }
940
941 TEST_F(NeighborhoodSearchTest, GridSearch2DPBC)
942 {
943     const NeighborhoodSearchTestData& data = RandomBox2DPBCData::get();
944
945     nb_.setCutoff(data.cutoff_);
946     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
947     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
948     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
949
950     testIsWithin(&search, data);
951     testMinimumDistance(&search, data);
952     testNearestPoint(&search, data);
953     testPairSearch(&search, data);
954 }
955
956 TEST_F(NeighborhoodSearchTest, GridSearchNoPBC)
957 {
958     const NeighborhoodSearchTestData& data = RandomBoxNoPBCData::get();
959
960     nb_.setCutoff(data.cutoff_);
961     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
962     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
963     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
964
965     testPairSearch(&search, data);
966 }
967
968 TEST_F(NeighborhoodSearchTest, GridSearchXYBox)
969 {
970     const NeighborhoodSearchTestData& data = RandomBoxXYFullPBCData::get();
971
972     nb_.setCutoff(data.cutoff_);
973     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
974     nb_.setXYMode(true);
975     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
976     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
977
978     testIsWithin(&search, data);
979     testMinimumDistance(&search, data);
980     testNearestPoint(&search, data);
981     testPairSearch(&search, data);
982 }
983
984 TEST_F(NeighborhoodSearchTest, SimpleSelfPairsSearch)
985 {
986     const NeighborhoodSearchTestData& data = TrivialSelfPairsTestData::get();
987
988     nb_.setCutoff(data.cutoff_);
989     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
990     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
991     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
992
993     testPairSearchFull(&search, data, data.testPositions(), nullptr, {}, {}, true);
994 }
995
996 TEST_F(NeighborhoodSearchTest, GridSelfPairsSearch)
997 {
998     const NeighborhoodSearchTestData& data = RandomBoxSelfPairsData::get();
999
1000     nb_.setCutoff(data.cutoff_);
1001     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
1002     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
1003     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
1004
1005     testPairSearchFull(&search, data, data.testPositions(), nullptr, {}, {}, true);
1006 }
1007
1008 TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
1009 {
1010     const NeighborhoodSearchTestData& data = TrivialTestData::get();
1011
1012     nb_.setCutoff(data.cutoff_);
1013     gmx::AnalysisNeighborhoodSearch search1 = nb_.initSearch(&data.pbc_, data.refPositions());
1014     gmx::AnalysisNeighborhoodSearch search2 = nb_.initSearch(&data.pbc_, data.refPositions());
1015
1016     // These checks are fragile, and unfortunately depend on the random
1017     // engine used to create the test positions. There is no particular reason
1018     // why exactly particles 0 & 2 should have neighbors, but in this case they do.
1019     gmx::AnalysisNeighborhoodPairSearch pairSearch1 = search1.startPairSearch(data.testPosition(0));
1020     gmx::AnalysisNeighborhoodPairSearch pairSearch2 = search1.startPairSearch(data.testPosition(2));
1021
1022     testPairSearch(&search2, data);
1023
1024     gmx::AnalysisNeighborhoodPair pair;
1025     ASSERT_TRUE(pairSearch1.findNextPair(&pair))
1026             << "Test data did not contain any pairs for position 0 (problem in the test).";
1027     EXPECT_EQ(0, pair.testIndex());
1028     {
1029         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1030         EXPECT_TRUE(data.containsPair(0, searchPair));
1031     }
1032
1033     ASSERT_TRUE(pairSearch2.findNextPair(&pair))
1034             << "Test data did not contain any pairs for position 2 (problem in the test).";
1035     EXPECT_EQ(2, pair.testIndex());
1036     {
1037         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1038         EXPECT_TRUE(data.containsPair(2, searchPair));
1039     }
1040 }
1041
1042 TEST_F(NeighborhoodSearchTest, HandlesNoPBC)
1043 {
1044     const NeighborhoodSearchTestData& data = TrivialNoPBCTestData::get();
1045
1046     nb_.setCutoff(data.cutoff_);
1047     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
1048     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1049
1050     testIsWithin(&search, data);
1051     testMinimumDistance(&search, data);
1052     testNearestPoint(&search, data);
1053     testPairSearch(&search, data);
1054 }
1055
1056 TEST_F(NeighborhoodSearchTest, HandlesNullPBC)
1057 {
1058     const NeighborhoodSearchTestData& data = TrivialNoPBCTestData::get();
1059
1060     nb_.setCutoff(data.cutoff_);
1061     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(nullptr, data.refPositions());
1062     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1063
1064     testIsWithin(&search, data);
1065     testMinimumDistance(&search, data);
1066     testNearestPoint(&search, data);
1067     testPairSearch(&search, data);
1068 }
1069
1070 TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
1071 {
1072     const NeighborhoodSearchTestData& data = TrivialTestData::get();
1073
1074     nb_.setCutoff(data.cutoff_);
1075     gmx::AnalysisNeighborhoodSearch     search = nb_.initSearch(&data.pbc_, data.refPositions());
1076     gmx::AnalysisNeighborhoodPairSearch pairSearch = search.startPairSearch(data.testPositions());
1077     gmx::AnalysisNeighborhoodPair       pair;
1078     // TODO: This test needs to be adjusted if the grid search gets optimized
1079     // to loop over the test positions in cell order (first, the ordering
1080     // assumption here breaks, and second, it then needs to be tested
1081     // separately for simple and grid searches).
1082     int currentIndex = 0;
1083     while (pairSearch.findNextPair(&pair))
1084     {
1085         while (currentIndex < pair.testIndex())
1086         {
1087             ++currentIndex;
1088         }
1089         EXPECT_EQ(currentIndex, pair.testIndex());
1090         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1091         EXPECT_TRUE(data.containsPair(currentIndex, searchPair));
1092         pairSearch.skipRemainingPairsForTestPosition();
1093         ++currentIndex;
1094     }
1095 }
1096
1097 TEST_F(NeighborhoodSearchTest, SimpleSearchExclusions)
1098 {
1099     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
1100
1101     ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
1102     helper.generateExclusions();
1103
1104     nb_.setCutoff(data.cutoff_);
1105     nb_.setTopologyExclusions(helper.exclusions());
1106     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
1107     gmx::AnalysisNeighborhoodSearch search =
1108             nb_.initSearch(&data.pbc_, data.refPositions().exclusionIds(helper.refPosIds()));
1109     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1110
1111     testPairSearchFull(&search, data, data.testPositions().exclusionIds(helper.testPosIds()),
1112                        helper.exclusions(), {}, {}, false);
1113 }
1114
1115 TEST_F(NeighborhoodSearchTest, GridSearchExclusions)
1116 {
1117     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
1118
1119     ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
1120     helper.generateExclusions();
1121
1122     nb_.setCutoff(data.cutoff_);
1123     nb_.setTopologyExclusions(helper.exclusions());
1124     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
1125     gmx::AnalysisNeighborhoodSearch search =
1126             nb_.initSearch(&data.pbc_, data.refPositions().exclusionIds(helper.refPosIds()));
1127     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
1128
1129     testPairSearchFull(&search, data, data.testPositions().exclusionIds(helper.testPosIds()),
1130                        helper.exclusions(), {}, {}, false);
1131 }
1132
1133 } // namespace