Use ListOfLists in gmx_mtop_t and gmx_localtop_t
[alexxy/gromacs.git] / src / gromacs / selection / tests / nbsearch.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2013,2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief
37  * Tests selection neighborhood searching.
38  *
39  * \todo
40  * Increase coverage of these tests for different corner cases: other PBC cases
41  * than full 3D, large cutoffs (larger than half the box size), etc.
42  * At least some of these probably don't work correctly.
43  *
44  * \author Teemu Murtola <teemu.murtola@gmail.com>
45  * \ingroup module_selection
46  */
47 #include "gmxpre.h"
48
49 #include "gromacs/selection/nbsearch.h"
50
51 #include <cmath>
52
53 #include <algorithm>
54 #include <limits>
55 #include <map>
56 #include <numeric>
57 #include <vector>
58
59 #include <gtest/gtest.h>
60
61 #include "gromacs/math/functions.h"
62 #include "gromacs/math/vec.h"
63 #include "gromacs/pbcutil/pbc.h"
64 #include "gromacs/random/threefry.h"
65 #include "gromacs/random/uniformrealdistribution.h"
66 #include "gromacs/topology/block.h"
67 #include "gromacs/utility/listoflists.h"
68 #include "gromacs/utility/smalloc.h"
69 #include "gromacs/utility/stringutil.h"
70
71 #include "testutils/testasserts.h"
72
73
74 namespace
75 {
76
77 /********************************************************************
78  * NeighborhoodSearchTestData
79  */
80
81 class NeighborhoodSearchTestData
82 {
83 public:
84     struct RefPair
85     {
86         RefPair(int refIndex, real distance) :
87             refIndex(refIndex),
88             distance(distance),
89             bFound(false),
90             bExcluded(false),
91             bIndexed(true)
92         {
93         }
94
95         bool operator<(const RefPair& other) const { return refIndex < other.refIndex; }
96
97         int  refIndex;
98         real distance;
99         // The variables below are state variables that are only used
100         // during the actual testing after creating a copy of the reference
101         // pair list, not as part of the reference data.
102         // Simpler to have just a single structure for both purposes.
103         bool bFound;
104         bool bExcluded;
105         bool bIndexed;
106     };
107
108     struct TestPosition
109     {
110         explicit TestPosition(const rvec x) : refMinDist(0.0), refNearestPoint(-1)
111         {
112             copy_rvec(x, this->x);
113         }
114
115         rvec                 x;
116         real                 refMinDist;
117         int                  refNearestPoint;
118         std::vector<RefPair> refPairs;
119     };
120
121     typedef std::vector<TestPosition> TestPositionList;
122
123     NeighborhoodSearchTestData(uint64_t seed, real cutoff);
124
125     gmx::AnalysisNeighborhoodPositions refPositions() const
126     {
127         return gmx::AnalysisNeighborhoodPositions(refPos_);
128     }
129     gmx::AnalysisNeighborhoodPositions testPositions() const
130     {
131         if (testPos_.empty())
132         {
133             testPos_.reserve(testPositions_.size());
134             for (size_t i = 0; i < testPositions_.size(); ++i)
135             {
136                 testPos_.emplace_back(testPositions_[i].x);
137             }
138         }
139         return gmx::AnalysisNeighborhoodPositions(testPos_);
140     }
141     gmx::AnalysisNeighborhoodPositions testPosition(int index) const
142     {
143         return testPositions().selectSingleFromArray(index);
144     }
145
146     void addTestPosition(const rvec x)
147     {
148         GMX_RELEASE_ASSERT(testPos_.empty(), "Cannot add positions after testPositions() call");
149         testPositions_.emplace_back(x);
150     }
151     gmx::RVec        generateRandomPosition();
152     std::vector<int> generateIndex(int count, uint64_t seed) const;
153     void             generateRandomRefPositions(int count);
154     void             generateRandomTestPositions(int count);
155     void             useRefPositionsAsTestPositions();
156     void             computeReferences(t_pbc* pbc) { computeReferencesInternal(pbc, false); }
157     void             computeReferencesXY(t_pbc* pbc) { computeReferencesInternal(pbc, true); }
158
159     bool containsPair(int testIndex, const RefPair& pair) const
160     {
161         const std::vector<RefPair>&          refPairs = testPositions_[testIndex].refPairs;
162         std::vector<RefPair>::const_iterator foundRefPair =
163                 std::lower_bound(refPairs.begin(), refPairs.end(), pair);
164         return !(foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex);
165     }
166
167     // Return a tolerance that accounts for the magnitudes of the coordinates
168     // when doing subtractions, so that we set the ULP tolerance relative to the
169     // coordinate values rather than their difference.
170     // i.e., 10.0-9.9999999 will achieve a few ULP accuracy relative
171     // to 10.0, but a much larger error relative to the difference.
172     gmx::test::FloatingPointTolerance relativeTolerance() const
173     {
174         real magnitude = std::max(box_[XX][XX], std::max(box_[YY][YY], box_[ZZ][ZZ]));
175         return gmx::test::relativeToleranceAsUlp(magnitude, 4);
176     }
177
178     gmx::DefaultRandomEngine rng_;
179     real                     cutoff_;
180     matrix                   box_;
181     t_pbc                    pbc_;
182     int                      refPosCount_;
183     std::vector<gmx::RVec>   refPos_;
184     TestPositionList         testPositions_;
185
186 private:
187     void computeReferencesInternal(t_pbc* pbc, bool bXY);
188
189     mutable std::vector<gmx::RVec> testPos_;
190 };
191
192 //! Shorthand for a collection of reference pairs.
193 typedef std::vector<NeighborhoodSearchTestData::RefPair> RefPairList;
194
195 NeighborhoodSearchTestData::NeighborhoodSearchTestData(uint64_t seed, real cutoff) :
196     rng_(seed),
197     cutoff_(cutoff),
198     refPosCount_(0)
199 {
200     clear_mat(box_);
201     set_pbc(&pbc_, epbcNONE, box_);
202 }
203
204 gmx::RVec NeighborhoodSearchTestData::generateRandomPosition()
205 {
206     gmx::UniformRealDistribution<real> dist;
207     rvec                               fx, x;
208     fx[XX] = dist(rng_);
209     fx[YY] = dist(rng_);
210     fx[ZZ] = dist(rng_);
211     mvmul(box_, fx, x);
212     // Add a small displacement to allow positions outside the box
213     x[XX] += 0.2 * dist(rng_) - 0.1;
214     x[YY] += 0.2 * dist(rng_) - 0.1;
215     x[ZZ] += 0.2 * dist(rng_) - 0.1;
216     return x;
217 }
218
219 std::vector<int> NeighborhoodSearchTestData::generateIndex(int count, uint64_t seed) const
220 {
221     gmx::DefaultRandomEngine           rngIndex(seed);
222     gmx::UniformRealDistribution<real> dist;
223     std::vector<int>                   result;
224
225     for (int i = 0; i < count; ++i)
226     {
227         if (dist(rngIndex) > 0.5)
228         {
229             result.push_back(i);
230         }
231     }
232     return result;
233 }
234
235 void NeighborhoodSearchTestData::generateRandomRefPositions(int count)
236 {
237     refPosCount_ = count;
238     refPos_.reserve(count);
239     for (int i = 0; i < count; ++i)
240     {
241         refPos_.push_back(generateRandomPosition());
242     }
243 }
244
245 void NeighborhoodSearchTestData::generateRandomTestPositions(int count)
246 {
247     testPositions_.reserve(count);
248     for (int i = 0; i < count; ++i)
249     {
250         addTestPosition(generateRandomPosition());
251     }
252 }
253
254 void NeighborhoodSearchTestData::useRefPositionsAsTestPositions()
255 {
256     testPositions_.reserve(refPosCount_);
257     for (const auto& refPos : refPos_)
258     {
259         addTestPosition(refPos);
260     }
261 }
262
263 void NeighborhoodSearchTestData::computeReferencesInternal(t_pbc* pbc, bool bXY)
264 {
265     real cutoff = cutoff_;
266     if (cutoff <= 0)
267     {
268         cutoff = std::numeric_limits<real>::max();
269     }
270     for (TestPosition& testPos : testPositions_)
271     {
272         testPos.refMinDist      = cutoff;
273         testPos.refNearestPoint = -1;
274         testPos.refPairs.clear();
275         for (int j = 0; j < refPosCount_; ++j)
276         {
277             rvec dx;
278             if (pbc != nullptr)
279             {
280                 pbc_dx(pbc, testPos.x, refPos_[j], dx);
281             }
282             else
283             {
284                 rvec_sub(testPos.x, refPos_[j], dx);
285             }
286             // TODO: This may not work intuitively for 2D with the third box
287             // vector not parallel to the Z axis, but neither does the actual
288             // neighborhood search.
289             const real dist = !bXY ? norm(dx) : std::hypot(dx[XX], dx[YY]);
290             if (dist < testPos.refMinDist)
291             {
292                 testPos.refMinDist      = dist;
293                 testPos.refNearestPoint = j;
294             }
295             if (dist > 0 && dist <= cutoff)
296             {
297                 RefPair pair(j, dist);
298                 GMX_RELEASE_ASSERT(testPos.refPairs.empty() || testPos.refPairs.back() < pair,
299                                    "Reference pairs should be generated in sorted order");
300                 testPos.refPairs.push_back(pair);
301             }
302         }
303     }
304 }
305
306 /********************************************************************
307  * ExclusionsHelper
308  */
309
310 class ExclusionsHelper
311 {
312 public:
313     static void markExcludedPairs(RefPairList* refPairs, int testIndex, const gmx::ListOfLists<int>* excls);
314
315     ExclusionsHelper(int refPosCount, int testPosCount);
316
317     void generateExclusions();
318
319     const gmx::ListOfLists<int>* exclusions() const { return &excls_; }
320
321     gmx::ArrayRef<const int> refPosIds() const
322     {
323         return gmx::makeArrayRef(exclusionIds_).subArray(0, refPosCount_);
324     }
325     gmx::ArrayRef<const int> testPosIds() const
326     {
327         return gmx::makeArrayRef(exclusionIds_).subArray(0, testPosCount_);
328     }
329
330 private:
331     int                   refPosCount_;
332     int                   testPosCount_;
333     std::vector<int>      exclusionIds_;
334     gmx::ListOfLists<int> excls_;
335 };
336
337 // static
338 void ExclusionsHelper::markExcludedPairs(RefPairList* refPairs, int testIndex, const gmx::ListOfLists<int>* excls)
339 {
340     int count = 0;
341     for (const int excludedIndex : (*excls)[testIndex])
342     {
343         NeighborhoodSearchTestData::RefPair searchPair(excludedIndex, 0.0);
344         RefPairList::iterator               excludedRefPair =
345                 std::lower_bound(refPairs->begin(), refPairs->end(), searchPair);
346         if (excludedRefPair != refPairs->end() && excludedRefPair->refIndex == excludedIndex)
347         {
348             excludedRefPair->bFound    = true;
349             excludedRefPair->bExcluded = true;
350             ++count;
351         }
352     }
353 }
354
355 ExclusionsHelper::ExclusionsHelper(int refPosCount, int testPosCount) :
356     refPosCount_(refPosCount),
357     testPosCount_(testPosCount)
358 {
359     // Generate an array of 0, 1, 2, ...
360     // TODO: Make the tests work also with non-trivial exclusion IDs,
361     // and test that.
362     exclusionIds_.resize(std::max(refPosCount, testPosCount), 1);
363     exclusionIds_[0] = 0;
364     std::partial_sum(exclusionIds_.begin(), exclusionIds_.end(), exclusionIds_.begin());
365 }
366
367 void ExclusionsHelper::generateExclusions()
368 {
369     // TODO: Consider a better set of test data, where the density of the
370     // particles would be higher, or where the exclusions would not be random,
371     // to make a higher percentage of the exclusions to actually be within the
372     // cutoff.
373     std::vector<int> exclusionsForAtom;
374     for (int i = 0; i < testPosCount_; ++i)
375     {
376         exclusionsForAtom.clear();
377         for (int j = 0; j < 20; ++j)
378         {
379             exclusionsForAtom.push_back(i + j * 3);
380         }
381         excls_.pushBack(exclusionsForAtom);
382     }
383 }
384
385 /********************************************************************
386  * NeighborhoodSearchTest
387  */
388
389 class NeighborhoodSearchTest : public ::testing::Test
390 {
391 public:
392     void testIsWithin(gmx::AnalysisNeighborhoodSearch* search, const NeighborhoodSearchTestData& data);
393     void testMinimumDistance(gmx::AnalysisNeighborhoodSearch*  search,
394                              const NeighborhoodSearchTestData& data);
395     void testNearestPoint(gmx::AnalysisNeighborhoodSearch* search, const NeighborhoodSearchTestData& data);
396     void testPairSearch(gmx::AnalysisNeighborhoodSearch* search, const NeighborhoodSearchTestData& data);
397     void testPairSearchIndexed(gmx::AnalysisNeighborhood*        nb,
398                                const NeighborhoodSearchTestData& data,
399                                uint64_t                          seed);
400     void testPairSearchFull(gmx::AnalysisNeighborhoodSearch*          search,
401                             const NeighborhoodSearchTestData&         data,
402                             const gmx::AnalysisNeighborhoodPositions& pos,
403                             const gmx::ListOfLists<int>*              excls,
404                             const gmx::ArrayRef<const int>&           refIndices,
405                             const gmx::ArrayRef<const int>&           testIndices,
406                             bool                                      selfPairs);
407
408     gmx::AnalysisNeighborhood nb_;
409 };
410
411 void NeighborhoodSearchTest::testIsWithin(gmx::AnalysisNeighborhoodSearch*  search,
412                                           const NeighborhoodSearchTestData& data)
413 {
414     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
415     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
416     {
417         const bool bWithin = (i->refMinDist <= data.cutoff_);
418         EXPECT_EQ(bWithin, search->isWithin(i->x)) << "Distance is " << i->refMinDist;
419     }
420 }
421
422 void NeighborhoodSearchTest::testMinimumDistance(gmx::AnalysisNeighborhoodSearch*  search,
423                                                  const NeighborhoodSearchTestData& data)
424 {
425     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
426
427     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
428     {
429         const real refDist = i->refMinDist;
430         EXPECT_REAL_EQ_TOL(refDist, search->minimumDistance(i->x), data.relativeTolerance());
431     }
432 }
433
434 void NeighborhoodSearchTest::testNearestPoint(gmx::AnalysisNeighborhoodSearch*  search,
435                                               const NeighborhoodSearchTestData& data)
436 {
437     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
438     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
439     {
440         const gmx::AnalysisNeighborhoodPair pair = search->nearestPoint(i->x);
441         if (pair.isValid())
442         {
443             EXPECT_EQ(i->refNearestPoint, pair.refIndex());
444             EXPECT_EQ(0, pair.testIndex());
445             EXPECT_REAL_EQ_TOL(i->refMinDist, std::sqrt(pair.distance2()), data.relativeTolerance());
446         }
447         else
448         {
449             EXPECT_EQ(i->refNearestPoint, -1);
450         }
451     }
452 }
453
454 //! Helper function for formatting test failure messages.
455 std::string formatVector(const rvec x)
456 {
457     return gmx::formatString("[%.3f, %.3f, %.3f]", x[XX], x[YY], x[ZZ]);
458 }
459
460 /*! \brief
461  * Helper function to check that all expected pairs were found.
462  */
463 void checkAllPairsFound(const RefPairList&            refPairs,
464                         const std::vector<gmx::RVec>& refPos,
465                         int                           testPosIndex,
466                         const rvec                    testPos)
467 {
468     // This could be elegantly expressed with Google Mock matchers, but that
469     // has a significant effect on the runtime of the tests...
470     int                         count = 0;
471     RefPairList::const_iterator first;
472     for (RefPairList::const_iterator i = refPairs.begin(); i != refPairs.end(); ++i)
473     {
474         if (!i->bFound)
475         {
476             ++count;
477             first = i;
478         }
479     }
480     if (count > 0)
481     {
482         ADD_FAILURE() << "Some pairs (" << count << "/" << refPairs.size() << ") "
483                       << "within the cutoff were not found. First pair:\n"
484                       << " Ref: " << first->refIndex << " at "
485                       << formatVector(refPos[first->refIndex]) << "\n"
486                       << "Test: " << testPosIndex << " at " << formatVector(testPos) << "\n"
487                       << "Dist: " << first->distance;
488     }
489 }
490
491 void NeighborhoodSearchTest::testPairSearch(gmx::AnalysisNeighborhoodSearch*  search,
492                                             const NeighborhoodSearchTestData& data)
493 {
494     testPairSearchFull(search, data, data.testPositions(), nullptr, {}, {}, false);
495 }
496
497 void NeighborhoodSearchTest::testPairSearchIndexed(gmx::AnalysisNeighborhood*        nb,
498                                                    const NeighborhoodSearchTestData& data,
499                                                    uint64_t                          seed)
500 {
501     std::vector<int> refIndices(data.generateIndex(data.refPos_.size(), seed++));
502     std::vector<int> testIndices(data.generateIndex(data.testPositions_.size(), seed++));
503     gmx::AnalysisNeighborhoodSearch search =
504             nb->initSearch(&data.pbc_, data.refPositions().indexed(refIndices));
505     testPairSearchFull(&search, data, data.testPositions(), nullptr, refIndices, testIndices, false);
506 }
507
508 void NeighborhoodSearchTest::testPairSearchFull(gmx::AnalysisNeighborhoodSearch*          search,
509                                                 const NeighborhoodSearchTestData&         data,
510                                                 const gmx::AnalysisNeighborhoodPositions& pos,
511                                                 const gmx::ListOfLists<int>*              excls,
512                                                 const gmx::ArrayRef<const int>& refIndices,
513                                                 const gmx::ArrayRef<const int>& testIndices,
514                                                 bool                            selfPairs)
515 {
516     std::map<int, RefPairList> refPairs;
517     // TODO: Some parts of this code do not work properly if pos does not
518     // initially contain all the test positions.
519     if (testIndices.empty())
520     {
521         for (size_t i = 0; i < data.testPositions_.size(); ++i)
522         {
523             refPairs[i] = data.testPositions_[i].refPairs;
524         }
525     }
526     else
527     {
528         for (int index : testIndices)
529         {
530             refPairs[index] = data.testPositions_[index].refPairs;
531         }
532     }
533     if (excls != nullptr)
534     {
535         GMX_RELEASE_ASSERT(!selfPairs, "Self-pairs testing not implemented with exclusions");
536         for (auto& entry : refPairs)
537         {
538             const int testIndex = entry.first;
539             ExclusionsHelper::markExcludedPairs(&entry.second, testIndex, excls);
540         }
541     }
542     if (!refIndices.empty())
543     {
544         GMX_RELEASE_ASSERT(!selfPairs, "Self-pairs testing not implemented with indexing");
545         for (auto& entry : refPairs)
546         {
547             for (auto& refPair : entry.second)
548             {
549                 refPair.bIndexed = false;
550             }
551             for (int index : refIndices)
552             {
553                 NeighborhoodSearchTestData::RefPair searchPair(index, 0.0);
554                 auto refPair = std::lower_bound(entry.second.begin(), entry.second.end(), searchPair);
555                 if (refPair != entry.second.end() && refPair->refIndex == index)
556                 {
557                     refPair->bIndexed = true;
558                 }
559             }
560             for (auto& refPair : entry.second)
561             {
562                 if (!refPair.bIndexed)
563                 {
564                     refPair.bFound = true;
565                 }
566             }
567         }
568     }
569
570     gmx::AnalysisNeighborhoodPositions posCopy(pos);
571     if (!testIndices.empty())
572     {
573         posCopy.indexed(testIndices);
574     }
575     gmx::AnalysisNeighborhoodPairSearch pairSearch =
576             selfPairs ? search->startSelfPairSearch() : search->startPairSearch(posCopy);
577     gmx::AnalysisNeighborhoodPair pair;
578     while (pairSearch.findNextPair(&pair))
579     {
580         const int testIndex = (testIndices.empty() ? pair.testIndex() : testIndices[pair.testIndex()]);
581         const int refIndex = (refIndices.empty() ? pair.refIndex() : refIndices[pair.refIndex()]);
582
583         if (refPairs.count(testIndex) == 0)
584         {
585             ADD_FAILURE() << "Expected: No pairs are returned for test position " << testIndex << ".\n"
586                           << "  Actual: Pair with ref " << refIndex << " is returned.";
587             continue;
588         }
589         NeighborhoodSearchTestData::RefPair searchPair(refIndex, std::sqrt(pair.distance2()));
590         const auto                          foundRefPair =
591                 std::lower_bound(refPairs[testIndex].begin(), refPairs[testIndex].end(), searchPair);
592         if (foundRefPair == refPairs[testIndex].end() || foundRefPair->refIndex != refIndex)
593         {
594             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
595                           << ") is not within the cutoff.\n"
596                           << "  Actual: It is returned.";
597         }
598         else if (foundRefPair->bExcluded)
599         {
600             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
601                           << ") is excluded from the search.\n"
602                           << "  Actual: It is returned.";
603         }
604         else if (!foundRefPair->bIndexed)
605         {
606             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
607                           << ") is not part of the indexed set.\n"
608                           << "  Actual: It is returned.";
609         }
610         else if (foundRefPair->bFound)
611         {
612             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
613                           << ") is returned only once.\n"
614                           << "  Actual: It is returned multiple times.";
615             return;
616         }
617         else
618         {
619             foundRefPair->bFound = true;
620
621             EXPECT_REAL_EQ_TOL(foundRefPair->distance, searchPair.distance, data.relativeTolerance())
622                     << "Distance computed by the neighborhood search does not match.";
623             if (selfPairs)
624             {
625                 searchPair              = NeighborhoodSearchTestData::RefPair(testIndex, 0.0);
626                 const auto otherRefPair = std::lower_bound(refPairs[refIndex].begin(),
627                                                            refPairs[refIndex].end(), searchPair);
628                 GMX_RELEASE_ASSERT(otherRefPair != refPairs[refIndex].end(),
629                                    "Precomputed reference data is not symmetric");
630                 otherRefPair->bFound = true;
631             }
632         }
633     }
634
635     for (auto& entry : refPairs)
636     {
637         const int testIndex = entry.first;
638         checkAllPairsFound(entry.second, data.refPos_, testIndex, data.testPositions_[testIndex].x);
639     }
640 }
641
642 /********************************************************************
643  * Test data generation
644  */
645
646 class TrivialTestData
647 {
648 public:
649     static const NeighborhoodSearchTestData& get()
650     {
651         static TrivialTestData singleton;
652         return singleton.data_;
653     }
654
655     TrivialTestData() : data_(12345, 1.0)
656     {
657         // Make the box so small we are virtually guaranteed to have
658         // several neighbors for the five test positions
659         data_.box_[XX][XX] = 3.0;
660         data_.box_[YY][YY] = 3.0;
661         data_.box_[ZZ][ZZ] = 3.0;
662         data_.generateRandomRefPositions(10);
663         data_.generateRandomTestPositions(5);
664         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
665         data_.computeReferences(&data_.pbc_);
666     }
667
668 private:
669     NeighborhoodSearchTestData data_;
670 };
671
672 class TrivialSelfPairsTestData
673 {
674 public:
675     static const NeighborhoodSearchTestData& get()
676     {
677         static TrivialSelfPairsTestData singleton;
678         return singleton.data_;
679     }
680
681     TrivialSelfPairsTestData() : data_(12345, 1.0)
682     {
683         data_.box_[XX][XX] = 3.0;
684         data_.box_[YY][YY] = 3.0;
685         data_.box_[ZZ][ZZ] = 3.0;
686         data_.generateRandomRefPositions(20);
687         data_.useRefPositionsAsTestPositions();
688         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
689         data_.computeReferences(&data_.pbc_);
690     }
691
692 private:
693     NeighborhoodSearchTestData data_;
694 };
695
696 class TrivialNoPBCTestData
697 {
698 public:
699     static const NeighborhoodSearchTestData& get()
700     {
701         static TrivialNoPBCTestData singleton;
702         return singleton.data_;
703     }
704
705     TrivialNoPBCTestData() : data_(12345, 1.0)
706     {
707         data_.generateRandomRefPositions(10);
708         data_.generateRandomTestPositions(5);
709         data_.computeReferences(nullptr);
710     }
711
712 private:
713     NeighborhoodSearchTestData data_;
714 };
715
716 class RandomBoxFullPBCData
717 {
718 public:
719     static const NeighborhoodSearchTestData& get()
720     {
721         static RandomBoxFullPBCData singleton;
722         return singleton.data_;
723     }
724
725     RandomBoxFullPBCData() : data_(12345, 1.0)
726     {
727         data_.box_[XX][XX] = 10.0;
728         data_.box_[YY][YY] = 5.0;
729         data_.box_[ZZ][ZZ] = 7.0;
730         // TODO: Consider whether manually picking some positions would give better
731         // test coverage.
732         data_.generateRandomRefPositions(1000);
733         data_.generateRandomTestPositions(100);
734         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
735         data_.computeReferences(&data_.pbc_);
736     }
737
738 private:
739     NeighborhoodSearchTestData data_;
740 };
741
742 class RandomBoxSelfPairsData
743 {
744 public:
745     static const NeighborhoodSearchTestData& get()
746     {
747         static RandomBoxSelfPairsData singleton;
748         return singleton.data_;
749     }
750
751     RandomBoxSelfPairsData() : data_(12345, 1.0)
752     {
753         data_.box_[XX][XX] = 10.0;
754         data_.box_[YY][YY] = 5.0;
755         data_.box_[ZZ][ZZ] = 7.0;
756         data_.generateRandomRefPositions(1000);
757         data_.useRefPositionsAsTestPositions();
758         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
759         data_.computeReferences(&data_.pbc_);
760     }
761
762 private:
763     NeighborhoodSearchTestData data_;
764 };
765
766 class RandomBoxXYFullPBCData
767 {
768 public:
769     static const NeighborhoodSearchTestData& get()
770     {
771         static RandomBoxXYFullPBCData singleton;
772         return singleton.data_;
773     }
774
775     RandomBoxXYFullPBCData() : data_(54321, 1.0)
776     {
777         data_.box_[XX][XX] = 10.0;
778         data_.box_[YY][YY] = 5.0;
779         data_.box_[ZZ][ZZ] = 7.0;
780         // TODO: Consider whether manually picking some positions would give better
781         // test coverage.
782         data_.generateRandomRefPositions(1000);
783         data_.generateRandomTestPositions(100);
784         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
785         data_.computeReferencesXY(&data_.pbc_);
786     }
787
788 private:
789     NeighborhoodSearchTestData data_;
790 };
791
792 class RandomTriclinicFullPBCData
793 {
794 public:
795     static const NeighborhoodSearchTestData& get()
796     {
797         static RandomTriclinicFullPBCData singleton;
798         return singleton.data_;
799     }
800
801     RandomTriclinicFullPBCData() : data_(12345, 1.0)
802     {
803         data_.box_[XX][XX] = 5.0;
804         data_.box_[YY][XX] = 2.5;
805         data_.box_[YY][YY] = 2.5 * std::sqrt(3.0);
806         data_.box_[ZZ][XX] = 2.5;
807         data_.box_[ZZ][YY] = 2.5 * std::sqrt(1.0 / 3.0);
808         data_.box_[ZZ][ZZ] = 5.0 * std::sqrt(2.0 / 3.0);
809         // TODO: Consider whether manually picking some positions would give better
810         // test coverage.
811         data_.generateRandomRefPositions(1000);
812         data_.generateRandomTestPositions(100);
813         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
814         data_.computeReferences(&data_.pbc_);
815     }
816
817 private:
818     NeighborhoodSearchTestData data_;
819 };
820
821 class RandomBox2DPBCData
822 {
823 public:
824     static const NeighborhoodSearchTestData& get()
825     {
826         static RandomBox2DPBCData singleton;
827         return singleton.data_;
828     }
829
830     RandomBox2DPBCData() : data_(12345, 1.0)
831     {
832         data_.box_[XX][XX] = 10.0;
833         data_.box_[YY][YY] = 7.0;
834         data_.box_[ZZ][ZZ] = 5.0;
835         // TODO: Consider whether manually picking some positions would give better
836         // test coverage.
837         data_.generateRandomRefPositions(1000);
838         data_.generateRandomTestPositions(100);
839         set_pbc(&data_.pbc_, epbcXY, data_.box_);
840         data_.computeReferences(&data_.pbc_);
841     }
842
843 private:
844     NeighborhoodSearchTestData data_;
845 };
846
847 class RandomBoxNoPBCData
848 {
849 public:
850     static const NeighborhoodSearchTestData& get()
851     {
852         static RandomBoxNoPBCData singleton;
853         return singleton.data_;
854     }
855
856     RandomBoxNoPBCData() : data_(12345, 1.0)
857     {
858         data_.box_[XX][XX] = 10.0;
859         data_.box_[YY][YY] = 5.0;
860         data_.box_[ZZ][ZZ] = 7.0;
861         // TODO: Consider whether manually picking some positions would give better
862         // test coverage.
863         data_.generateRandomRefPositions(1000);
864         data_.generateRandomTestPositions(100);
865         set_pbc(&data_.pbc_, epbcNONE, data_.box_);
866         data_.computeReferences(nullptr);
867     }
868
869 private:
870     NeighborhoodSearchTestData data_;
871 };
872
873 /********************************************************************
874  * Actual tests
875  */
876
877 TEST_F(NeighborhoodSearchTest, SimpleSearch)
878 {
879     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
880
881     nb_.setCutoff(data.cutoff_);
882     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
883     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
884     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
885
886     testIsWithin(&search, data);
887     testMinimumDistance(&search, data);
888     testNearestPoint(&search, data);
889     testPairSearch(&search, data);
890
891     search.reset();
892     testPairSearchIndexed(&nb_, data, 123);
893 }
894
895 TEST_F(NeighborhoodSearchTest, SimpleSearchXY)
896 {
897     const NeighborhoodSearchTestData& data = RandomBoxXYFullPBCData::get();
898
899     nb_.setCutoff(data.cutoff_);
900     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
901     nb_.setXYMode(true);
902     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
903     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
904
905     testIsWithin(&search, data);
906     testMinimumDistance(&search, data);
907     testNearestPoint(&search, data);
908     testPairSearch(&search, data);
909 }
910
911 TEST_F(NeighborhoodSearchTest, GridSearchBox)
912 {
913     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
914
915     nb_.setCutoff(data.cutoff_);
916     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
917     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
918     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
919
920     testIsWithin(&search, data);
921     testMinimumDistance(&search, data);
922     testNearestPoint(&search, data);
923     testPairSearch(&search, data);
924
925     search.reset();
926     testPairSearchIndexed(&nb_, data, 456);
927 }
928
929 TEST_F(NeighborhoodSearchTest, GridSearchTriclinic)
930 {
931     const NeighborhoodSearchTestData& data = RandomTriclinicFullPBCData::get();
932
933     nb_.setCutoff(data.cutoff_);
934     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
935     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
936     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
937
938     testPairSearch(&search, data);
939 }
940
941 TEST_F(NeighborhoodSearchTest, GridSearch2DPBC)
942 {
943     const NeighborhoodSearchTestData& data = RandomBox2DPBCData::get();
944
945     nb_.setCutoff(data.cutoff_);
946     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
947     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
948     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
949
950     testIsWithin(&search, data);
951     testMinimumDistance(&search, data);
952     testNearestPoint(&search, data);
953     testPairSearch(&search, data);
954 }
955
956 TEST_F(NeighborhoodSearchTest, GridSearchNoPBC)
957 {
958     const NeighborhoodSearchTestData& data = RandomBoxNoPBCData::get();
959
960     nb_.setCutoff(data.cutoff_);
961     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
962     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
963     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
964
965     testPairSearch(&search, data);
966 }
967
968 TEST_F(NeighborhoodSearchTest, GridSearchXYBox)
969 {
970     const NeighborhoodSearchTestData& data = RandomBoxXYFullPBCData::get();
971
972     nb_.setCutoff(data.cutoff_);
973     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
974     nb_.setXYMode(true);
975     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
976     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
977
978     testIsWithin(&search, data);
979     testMinimumDistance(&search, data);
980     testNearestPoint(&search, data);
981     testPairSearch(&search, data);
982 }
983
984 TEST_F(NeighborhoodSearchTest, SimpleSelfPairsSearch)
985 {
986     const NeighborhoodSearchTestData& data = TrivialSelfPairsTestData::get();
987
988     nb_.setCutoff(data.cutoff_);
989     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
990     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
991     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
992
993     testPairSearchFull(&search, data, data.testPositions(), nullptr, {}, {}, true);
994 }
995
996 TEST_F(NeighborhoodSearchTest, GridSelfPairsSearch)
997 {
998     const NeighborhoodSearchTestData& data = RandomBoxSelfPairsData::get();
999
1000     nb_.setCutoff(data.cutoff_);
1001     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
1002     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
1003     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
1004
1005     testPairSearchFull(&search, data, data.testPositions(), nullptr, {}, {}, true);
1006 }
1007
// Checks that searches can be interleaved: two pair searches are started and
// left unfinished, a complete pair search runs on a second search object, and
// only then are the first results of the two pending pair searches consumed
// and verified. The statement order below is what is being tested.
TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
{
    const NeighborhoodSearchTestData& data = TrivialTestData::get();

    nb_.setCutoff(data.cutoff_);
    // Two independent search objects created from the same neighborhood object.
    gmx::AnalysisNeighborhoodSearch search1 = nb_.initSearch(&data.pbc_, data.refPositions());
    gmx::AnalysisNeighborhoodSearch search2 = nb_.initSearch(&data.pbc_, data.refPositions());

    // These checks are fragile, and unfortunately depend on the random
    // engine used to create the test positions. There is no particular reason
    // why exactly particles 0 & 2 should have neighbors, but in this case they do.
    // NOTE(review): both pair searches are started from search1, so this
    // exercises two pending pair searches on one search object while search2
    // runs a full search; confirm this is intended rather than a typo for search2.
    gmx::AnalysisNeighborhoodPairSearch pairSearch1 = search1.startPairSearch(data.testPosition(0));
    gmx::AnalysisNeighborhoodPairSearch pairSearch2 = search1.startPairSearch(data.testPosition(2));

    // Run a complete search on the other object while both pair searches above are pending.
    testPairSearch(&search2, data);

    // The same pair object is reused for both pending searches.
    gmx::AnalysisNeighborhoodPair pair;
    ASSERT_TRUE(pairSearch1.findNextPair(&pair))
            << "Test data did not contain any pairs for position 0 (problem in the test).";
    EXPECT_EQ(0, pair.testIndex());
    {
        // The reported pair must be one of the reference pairs for test position 0.
        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
        EXPECT_TRUE(data.containsPair(0, searchPair));
    }

    ASSERT_TRUE(pairSearch2.findNextPair(&pair))
            << "Test data did not contain any pairs for position 2 (problem in the test).";
    EXPECT_EQ(2, pair.testIndex());
    {
        // The reported pair must be one of the reference pairs for test position 2.
        NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
        EXPECT_TRUE(data.containsPair(2, searchPair));
    }
}
1041
1042 TEST_F(NeighborhoodSearchTest, HandlesNoPBC)
1043 {
1044     const NeighborhoodSearchTestData& data = TrivialNoPBCTestData::get();
1045
1046     nb_.setCutoff(data.cutoff_);
1047     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
1048     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1049
1050     testIsWithin(&search, data);
1051     testMinimumDistance(&search, data);
1052     testNearestPoint(&search, data);
1053     testPairSearch(&search, data);
1054 }
1055
1056 TEST_F(NeighborhoodSearchTest, HandlesNullPBC)
1057 {
1058     const NeighborhoodSearchTestData& data = TrivialNoPBCTestData::get();
1059
1060     nb_.setCutoff(data.cutoff_);
1061     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(nullptr, data.refPositions());
1062     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1063
1064     testIsWithin(&search, data);
1065     testMinimumDistance(&search, data);
1066     testNearestPoint(&search, data);
1067     testPairSearch(&search, data);
1068 }
1069
1070 TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
1071 {
1072     const NeighborhoodSearchTestData& data = TrivialTestData::get();
1073
1074     nb_.setCutoff(data.cutoff_);
1075     gmx::AnalysisNeighborhoodSearch     search = nb_.initSearch(&data.pbc_, data.refPositions());
1076     gmx::AnalysisNeighborhoodPairSearch pairSearch = search.startPairSearch(data.testPositions());
1077     gmx::AnalysisNeighborhoodPair       pair;
1078     // TODO: This test needs to be adjusted if the grid search gets optimized
1079     // to loop over the test positions in cell order (first, the ordering
1080     // assumption here breaks, and second, it then needs to be tested
1081     // separately for simple and grid searches).
1082     int currentIndex = 0;
1083     while (pairSearch.findNextPair(&pair))
1084     {
1085         while (currentIndex < pair.testIndex())
1086         {
1087             ++currentIndex;
1088         }
1089         EXPECT_EQ(currentIndex, pair.testIndex());
1090         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1091         EXPECT_TRUE(data.containsPair(currentIndex, searchPair));
1092         pairSearch.skipRemainingPairsForTestPosition();
1093         ++currentIndex;
1094     }
1095 }
1096
1097 TEST_F(NeighborhoodSearchTest, SimpleSearchExclusions)
1098 {
1099     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
1100
1101     ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
1102     helper.generateExclusions();
1103
1104     nb_.setCutoff(data.cutoff_);
1105     nb_.setTopologyExclusions(helper.exclusions());
1106     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
1107     gmx::AnalysisNeighborhoodSearch search =
1108             nb_.initSearch(&data.pbc_, data.refPositions().exclusionIds(helper.refPosIds()));
1109     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1110
1111     testPairSearchFull(&search, data, data.testPositions().exclusionIds(helper.testPosIds()),
1112                        helper.exclusions(), {}, {}, false);
1113 }
1114
1115 TEST_F(NeighborhoodSearchTest, GridSearchExclusions)
1116 {
1117     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
1118
1119     ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
1120     helper.generateExclusions();
1121
1122     nb_.setCutoff(data.cutoff_);
1123     nb_.setTopologyExclusions(helper.exclusions());
1124     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
1125     gmx::AnalysisNeighborhoodSearch search =
1126             nb_.initSearch(&data.pbc_, data.refPositions().exclusionIds(helper.refPosIds()));
1127     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
1128
1129     testPairSearchFull(&search, data, data.testPositions().exclusionIds(helper.testPosIds()),
1130                        helper.exclusions(), {}, {}, false);
1131 }
1132
1133 } // namespace