e32d8495099ab248196ae3c4336f016710f0e00b
[alexxy/gromacs.git] / src / gromacs / selection / tests / nbsearch.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2013,2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief
37  * Tests selection neighborhood searching.
38  *
39  * \todo
40  * Increase coverage of these tests for different corner cases: other PBC cases
41  * than full 3D, large cutoffs (larger than half the box size), etc.
42  * At least some of these probably don't work correctly.
43  *
44  * \author Teemu Murtola <teemu.murtola@gmail.com>
45  * \ingroup module_selection
46  */
47 #include "gmxpre.h"
48
49 #include "gromacs/selection/nbsearch.h"
50
51 #include <cmath>
52
53 #include <algorithm>
54 #include <limits>
55 #include <map>
56 #include <numeric>
57 #include <vector>
58
59 #include <gtest/gtest.h>
60
61 #include "gromacs/math/functions.h"
62 #include "gromacs/math/vec.h"
63 #include "gromacs/pbcutil/pbc.h"
64 #include "gromacs/random/threefry.h"
65 #include "gromacs/random/uniformrealdistribution.h"
66 #include "gromacs/topology/block.h"
67 #include "gromacs/utility/listoflists.h"
68 #include "gromacs/utility/smalloc.h"
69 #include "gromacs/utility/stringutil.h"
70
71 #include "testutils/testasserts.h"
72
73
74 namespace
75 {
76
77 /********************************************************************
78  * NeighborhoodSearchTestData
79  */
80
81 class NeighborhoodSearchTestData
82 {
83 public:
84     struct RefPair
85     {
86         RefPair(int refIndex, real distance) :
87             refIndex(refIndex),
88             distance(distance),
89             bFound(false),
90             bExcluded(false),
91             bIndexed(true)
92         {
93         }
94
95         bool operator<(const RefPair& other) const { return refIndex < other.refIndex; }
96
97         int  refIndex;
98         real distance;
99         // The variables below are state variables that are only used
100         // during the actual testing after creating a copy of the reference
101         // pair list, not as part of the reference data.
102         // Simpler to have just a single structure for both purposes.
103         bool bFound;
104         bool bExcluded;
105         bool bIndexed;
106     };
107
108     struct TestPosition
109     {
110         explicit TestPosition(const rvec x) : refMinDist(0.0), refNearestPoint(-1)
111         {
112             copy_rvec(x, this->x);
113         }
114
115         rvec                 x;
116         real                 refMinDist;
117         int                  refNearestPoint;
118         std::vector<RefPair> refPairs;
119     };
120
121     typedef std::vector<TestPosition> TestPositionList;
122
123     NeighborhoodSearchTestData(uint64_t seed, real cutoff);
124
125     gmx::AnalysisNeighborhoodPositions refPositions() const
126     {
127         return gmx::AnalysisNeighborhoodPositions(refPos_);
128     }
129     gmx::AnalysisNeighborhoodPositions testPositions() const
130     {
131         if (testPos_.empty())
132         {
133             testPos_.reserve(testPositions_.size());
134             for (size_t i = 0; i < testPositions_.size(); ++i)
135             {
136                 testPos_.emplace_back(testPositions_[i].x);
137             }
138         }
139         return gmx::AnalysisNeighborhoodPositions(testPos_);
140     }
141     gmx::AnalysisNeighborhoodPositions testPosition(int index) const
142     {
143         return testPositions().selectSingleFromArray(index);
144     }
145
146     void addTestPosition(const rvec x)
147     {
148         GMX_RELEASE_ASSERT(testPos_.empty(), "Cannot add positions after testPositions() call");
149         testPositions_.emplace_back(x);
150     }
151     gmx::RVec        generateRandomPosition();
152     std::vector<int> generateIndex(int count, uint64_t seed) const;
153     void             generateRandomRefPositions(int count);
154     void             generateRandomTestPositions(int count);
155     void             useRefPositionsAsTestPositions();
156     void             computeReferences(t_pbc* pbc) { computeReferencesInternal(pbc, false); }
157     void             computeReferencesXY(t_pbc* pbc) { computeReferencesInternal(pbc, true); }
158
159     bool containsPair(int testIndex, const RefPair& pair) const
160     {
161         const std::vector<RefPair>&          refPairs = testPositions_[testIndex].refPairs;
162         std::vector<RefPair>::const_iterator foundRefPair =
163                 std::lower_bound(refPairs.begin(), refPairs.end(), pair);
164         return !(foundRefPair == refPairs.end() || foundRefPair->refIndex != pair.refIndex);
165     }
166
167     // Return a tolerance that accounts for the magnitudes of the coordinates
168     // when doing subtractions, so that we set the ULP tolerance relative to the
169     // coordinate values rather than their difference.
170     // i.e., 10.0-9.9999999 will achieve a few ULP accuracy relative
171     // to 10.0, but a much larger error relative to the difference.
172     gmx::test::FloatingPointTolerance relativeTolerance() const
173     {
174         real magnitude = std::max(box_[XX][XX], std::max(box_[YY][YY], box_[ZZ][ZZ]));
175         return gmx::test::relativeToleranceAsUlp(magnitude, 4);
176     }
177
178     gmx::DefaultRandomEngine rng_;
179     real                     cutoff_;
180     matrix                   box_;
181     t_pbc                    pbc_;
182     int                      refPosCount_;
183     std::vector<gmx::RVec>   refPos_;
184     TestPositionList         testPositions_;
185
186 private:
187     void computeReferencesInternal(t_pbc* pbc, bool bXY);
188
189     mutable std::vector<gmx::RVec> testPos_;
190 };
191
192 //! Shorthand for a collection of reference pairs.
193 typedef std::vector<NeighborhoodSearchTestData::RefPair> RefPairList;
194
195 NeighborhoodSearchTestData::NeighborhoodSearchTestData(uint64_t seed, real cutoff) :
196     rng_(seed),
197     cutoff_(cutoff),
198     refPosCount_(0)
199 {
200     clear_mat(box_);
201     set_pbc(&pbc_, epbcNONE, box_);
202 }
203
204 gmx::RVec NeighborhoodSearchTestData::generateRandomPosition()
205 {
206     gmx::UniformRealDistribution<real> dist;
207     rvec                               fx, x;
208     fx[XX] = dist(rng_);
209     fx[YY] = dist(rng_);
210     fx[ZZ] = dist(rng_);
211     mvmul(box_, fx, x);
212     // Add a small displacement to allow positions outside the box
213     x[XX] += 0.2 * dist(rng_) - 0.1;
214     x[YY] += 0.2 * dist(rng_) - 0.1;
215     x[ZZ] += 0.2 * dist(rng_) - 0.1;
216     return x;
217 }
218
219 std::vector<int> NeighborhoodSearchTestData::generateIndex(int count, uint64_t seed) const
220 {
221     gmx::DefaultRandomEngine           rngIndex(seed);
222     gmx::UniformRealDistribution<real> dist;
223     std::vector<int>                   result;
224
225     for (int i = 0; i < count; ++i)
226     {
227         if (dist(rngIndex) > 0.5)
228         {
229             result.push_back(i);
230         }
231     }
232     return result;
233 }
234
235 void NeighborhoodSearchTestData::generateRandomRefPositions(int count)
236 {
237     refPosCount_ = count;
238     refPos_.reserve(count);
239     for (int i = 0; i < count; ++i)
240     {
241         refPos_.push_back(generateRandomPosition());
242     }
243 }
244
245 void NeighborhoodSearchTestData::generateRandomTestPositions(int count)
246 {
247     testPositions_.reserve(count);
248     for (int i = 0; i < count; ++i)
249     {
250         addTestPosition(generateRandomPosition());
251     }
252 }
253
254 void NeighborhoodSearchTestData::useRefPositionsAsTestPositions()
255 {
256     testPositions_.reserve(refPosCount_);
257     for (const auto& refPos : refPos_)
258     {
259         addTestPosition(refPos);
260     }
261 }
262
263 void NeighborhoodSearchTestData::computeReferencesInternal(t_pbc* pbc, bool bXY)
264 {
265     real cutoff = cutoff_;
266     if (cutoff <= 0)
267     {
268         cutoff = std::numeric_limits<real>::max();
269     }
270     for (TestPosition& testPos : testPositions_)
271     {
272         testPos.refMinDist      = cutoff;
273         testPos.refNearestPoint = -1;
274         testPos.refPairs.clear();
275         for (int j = 0; j < refPosCount_; ++j)
276         {
277             rvec dx;
278             if (pbc != nullptr)
279             {
280                 pbc_dx(pbc, testPos.x, refPos_[j], dx);
281             }
282             else
283             {
284                 rvec_sub(testPos.x, refPos_[j], dx);
285             }
286             // TODO: This may not work intuitively for 2D with the third box
287             // vector not parallel to the Z axis, but neither does the actual
288             // neighborhood search.
289             const real dist = !bXY ? norm(dx) : std::hypot(dx[XX], dx[YY]);
290             if (dist < testPos.refMinDist)
291             {
292                 testPos.refMinDist      = dist;
293                 testPos.refNearestPoint = j;
294             }
295             if (dist > 0 && dist <= cutoff)
296             {
297                 RefPair pair(j, dist);
298                 GMX_RELEASE_ASSERT(testPos.refPairs.empty() || testPos.refPairs.back() < pair,
299                                    "Reference pairs should be generated in sorted order");
300                 testPos.refPairs.push_back(pair);
301             }
302         }
303     }
304 }
305
306 /********************************************************************
307  * ExclusionsHelper
308  */
309
310 class ExclusionsHelper
311 {
312 public:
313     static void markExcludedPairs(RefPairList* refPairs, int testIndex, const gmx::ListOfLists<int>* excls);
314
315     ExclusionsHelper(int refPosCount, int testPosCount);
316
317     void generateExclusions();
318
319     const gmx::ListOfLists<int>* exclusions() const { return &excls_; }
320
321     gmx::ArrayRef<const int> refPosIds() const
322     {
323         return gmx::makeArrayRef(exclusionIds_).subArray(0, refPosCount_);
324     }
325     gmx::ArrayRef<const int> testPosIds() const
326     {
327         return gmx::makeArrayRef(exclusionIds_).subArray(0, testPosCount_);
328     }
329
330 private:
331     int                   refPosCount_;
332     int                   testPosCount_;
333     std::vector<int>      exclusionIds_;
334     gmx::ListOfLists<int> excls_;
335 };
336
337 // static
338 void ExclusionsHelper::markExcludedPairs(RefPairList* refPairs, int testIndex, const gmx::ListOfLists<int>* excls)
339 {
340     int count = 0;
341     for (const int excludedIndex : (*excls)[testIndex])
342     {
343         NeighborhoodSearchTestData::RefPair searchPair(excludedIndex, 0.0);
344         RefPairList::iterator               excludedRefPair =
345                 std::lower_bound(refPairs->begin(), refPairs->end(), searchPair);
346         if (excludedRefPair != refPairs->end() && excludedRefPair->refIndex == excludedIndex)
347         {
348             excludedRefPair->bFound    = true;
349             excludedRefPair->bExcluded = true;
350             ++count;
351         }
352     }
353 }
354
355 ExclusionsHelper::ExclusionsHelper(int refPosCount, int testPosCount) :
356     refPosCount_(refPosCount),
357     testPosCount_(testPosCount)
358 {
359     // Generate an array of 0, 1, 2, ...
360     // TODO: Make the tests work also with non-trivial exclusion IDs,
361     // and test that.
362     exclusionIds_.resize(std::max(refPosCount, testPosCount), 1);
363     exclusionIds_[0] = 0;
364     std::partial_sum(exclusionIds_.begin(), exclusionIds_.end(), exclusionIds_.begin());
365 }
366
367 void ExclusionsHelper::generateExclusions()
368 {
369     // TODO: Consider a better set of test data, where the density of the
370     // particles would be higher, or where the exclusions would not be random,
371     // to make a higher percentage of the exclusions to actually be within the
372     // cutoff.
373     for (int i = 0; i < testPosCount_; ++i)
374     {
375         excls_.pushBackListOfSize(20);
376         gmx::ArrayRef<int> exclusionsForAtom = excls_.back();
377         for (int j = 0; j < 20; ++j)
378         {
379             exclusionsForAtom[j] = i + j * 3;
380         }
381     }
382 }
383
384 /********************************************************************
385  * NeighborhoodSearchTest
386  */
387
388 class NeighborhoodSearchTest : public ::testing::Test
389 {
390 public:
391     void testIsWithin(gmx::AnalysisNeighborhoodSearch* search, const NeighborhoodSearchTestData& data);
392     void testMinimumDistance(gmx::AnalysisNeighborhoodSearch*  search,
393                              const NeighborhoodSearchTestData& data);
394     void testNearestPoint(gmx::AnalysisNeighborhoodSearch* search, const NeighborhoodSearchTestData& data);
395     void testPairSearch(gmx::AnalysisNeighborhoodSearch* search, const NeighborhoodSearchTestData& data);
396     void testPairSearchIndexed(gmx::AnalysisNeighborhood*        nb,
397                                const NeighborhoodSearchTestData& data,
398                                uint64_t                          seed);
399     void testPairSearchFull(gmx::AnalysisNeighborhoodSearch*          search,
400                             const NeighborhoodSearchTestData&         data,
401                             const gmx::AnalysisNeighborhoodPositions& pos,
402                             const gmx::ListOfLists<int>*              excls,
403                             const gmx::ArrayRef<const int>&           refIndices,
404                             const gmx::ArrayRef<const int>&           testIndices,
405                             bool                                      selfPairs);
406
407     gmx::AnalysisNeighborhood nb_;
408 };
409
410 void NeighborhoodSearchTest::testIsWithin(gmx::AnalysisNeighborhoodSearch*  search,
411                                           const NeighborhoodSearchTestData& data)
412 {
413     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
414     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
415     {
416         const bool bWithin = (i->refMinDist <= data.cutoff_);
417         EXPECT_EQ(bWithin, search->isWithin(i->x)) << "Distance is " << i->refMinDist;
418     }
419 }
420
421 void NeighborhoodSearchTest::testMinimumDistance(gmx::AnalysisNeighborhoodSearch*  search,
422                                                  const NeighborhoodSearchTestData& data)
423 {
424     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
425
426     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
427     {
428         const real refDist = i->refMinDist;
429         EXPECT_REAL_EQ_TOL(refDist, search->minimumDistance(i->x), data.relativeTolerance());
430     }
431 }
432
433 void NeighborhoodSearchTest::testNearestPoint(gmx::AnalysisNeighborhoodSearch*  search,
434                                               const NeighborhoodSearchTestData& data)
435 {
436     NeighborhoodSearchTestData::TestPositionList::const_iterator i;
437     for (i = data.testPositions_.begin(); i != data.testPositions_.end(); ++i)
438     {
439         const gmx::AnalysisNeighborhoodPair pair = search->nearestPoint(i->x);
440         if (pair.isValid())
441         {
442             EXPECT_EQ(i->refNearestPoint, pair.refIndex());
443             EXPECT_EQ(0, pair.testIndex());
444             EXPECT_REAL_EQ_TOL(i->refMinDist, std::sqrt(pair.distance2()), data.relativeTolerance());
445         }
446         else
447         {
448             EXPECT_EQ(i->refNearestPoint, -1);
449         }
450     }
451 }
452
453 //! Helper function for formatting test failure messages.
454 std::string formatVector(const rvec x)
455 {
456     return gmx::formatString("[%.3f, %.3f, %.3f]", x[XX], x[YY], x[ZZ]);
457 }
458
459 /*! \brief
460  * Helper function to check that all expected pairs were found.
461  */
462 void checkAllPairsFound(const RefPairList&            refPairs,
463                         const std::vector<gmx::RVec>& refPos,
464                         int                           testPosIndex,
465                         const rvec                    testPos)
466 {
467     // This could be elegantly expressed with Google Mock matchers, but that
468     // has a significant effect on the runtime of the tests...
469     int                         count = 0;
470     RefPairList::const_iterator first;
471     for (RefPairList::const_iterator i = refPairs.begin(); i != refPairs.end(); ++i)
472     {
473         if (!i->bFound)
474         {
475             ++count;
476             first = i;
477         }
478     }
479     if (count > 0)
480     {
481         ADD_FAILURE() << "Some pairs (" << count << "/" << refPairs.size() << ") "
482                       << "within the cutoff were not found. First pair:\n"
483                       << " Ref: " << first->refIndex << " at "
484                       << formatVector(refPos[first->refIndex]) << "\n"
485                       << "Test: " << testPosIndex << " at " << formatVector(testPos) << "\n"
486                       << "Dist: " << first->distance;
487     }
488 }
489
490 void NeighborhoodSearchTest::testPairSearch(gmx::AnalysisNeighborhoodSearch*  search,
491                                             const NeighborhoodSearchTestData& data)
492 {
493     testPairSearchFull(search, data, data.testPositions(), nullptr, {}, {}, false);
494 }
495
496 void NeighborhoodSearchTest::testPairSearchIndexed(gmx::AnalysisNeighborhood*        nb,
497                                                    const NeighborhoodSearchTestData& data,
498                                                    uint64_t                          seed)
499 {
500     std::vector<int> refIndices(data.generateIndex(data.refPos_.size(), seed++));
501     std::vector<int> testIndices(data.generateIndex(data.testPositions_.size(), seed++));
502     gmx::AnalysisNeighborhoodSearch search =
503             nb->initSearch(&data.pbc_, data.refPositions().indexed(refIndices));
504     testPairSearchFull(&search, data, data.testPositions(), nullptr, refIndices, testIndices, false);
505 }
506
507 void NeighborhoodSearchTest::testPairSearchFull(gmx::AnalysisNeighborhoodSearch*          search,
508                                                 const NeighborhoodSearchTestData&         data,
509                                                 const gmx::AnalysisNeighborhoodPositions& pos,
510                                                 const gmx::ListOfLists<int>*              excls,
511                                                 const gmx::ArrayRef<const int>& refIndices,
512                                                 const gmx::ArrayRef<const int>& testIndices,
513                                                 bool                            selfPairs)
514 {
515     std::map<int, RefPairList> refPairs;
516     // TODO: Some parts of this code do not work properly if pos does not
517     // initially contain all the test positions.
518     if (testIndices.empty())
519     {
520         for (size_t i = 0; i < data.testPositions_.size(); ++i)
521         {
522             refPairs[i] = data.testPositions_[i].refPairs;
523         }
524     }
525     else
526     {
527         for (int index : testIndices)
528         {
529             refPairs[index] = data.testPositions_[index].refPairs;
530         }
531     }
532     if (excls != nullptr)
533     {
534         GMX_RELEASE_ASSERT(!selfPairs, "Self-pairs testing not implemented with exclusions");
535         for (auto& entry : refPairs)
536         {
537             const int testIndex = entry.first;
538             ExclusionsHelper::markExcludedPairs(&entry.second, testIndex, excls);
539         }
540     }
541     if (!refIndices.empty())
542     {
543         GMX_RELEASE_ASSERT(!selfPairs, "Self-pairs testing not implemented with indexing");
544         for (auto& entry : refPairs)
545         {
546             for (auto& refPair : entry.second)
547             {
548                 refPair.bIndexed = false;
549             }
550             for (int index : refIndices)
551             {
552                 NeighborhoodSearchTestData::RefPair searchPair(index, 0.0);
553                 auto refPair = std::lower_bound(entry.second.begin(), entry.second.end(), searchPair);
554                 if (refPair != entry.second.end() && refPair->refIndex == index)
555                 {
556                     refPair->bIndexed = true;
557                 }
558             }
559             for (auto& refPair : entry.second)
560             {
561                 if (!refPair.bIndexed)
562                 {
563                     refPair.bFound = true;
564                 }
565             }
566         }
567     }
568
569     gmx::AnalysisNeighborhoodPositions posCopy(pos);
570     if (!testIndices.empty())
571     {
572         posCopy.indexed(testIndices);
573     }
574     gmx::AnalysisNeighborhoodPairSearch pairSearch =
575             selfPairs ? search->startSelfPairSearch() : search->startPairSearch(posCopy);
576     gmx::AnalysisNeighborhoodPair pair;
577     while (pairSearch.findNextPair(&pair))
578     {
579         const int testIndex = (testIndices.empty() ? pair.testIndex() : testIndices[pair.testIndex()]);
580         const int refIndex = (refIndices.empty() ? pair.refIndex() : refIndices[pair.refIndex()]);
581
582         if (refPairs.count(testIndex) == 0)
583         {
584             ADD_FAILURE() << "Expected: No pairs are returned for test position " << testIndex << ".\n"
585                           << "  Actual: Pair with ref " << refIndex << " is returned.";
586             continue;
587         }
588         NeighborhoodSearchTestData::RefPair searchPair(refIndex, std::sqrt(pair.distance2()));
589         const auto                          foundRefPair =
590                 std::lower_bound(refPairs[testIndex].begin(), refPairs[testIndex].end(), searchPair);
591         if (foundRefPair == refPairs[testIndex].end() || foundRefPair->refIndex != refIndex)
592         {
593             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
594                           << ") is not within the cutoff.\n"
595                           << "  Actual: It is returned.";
596         }
597         else if (foundRefPair->bExcluded)
598         {
599             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
600                           << ") is excluded from the search.\n"
601                           << "  Actual: It is returned.";
602         }
603         else if (!foundRefPair->bIndexed)
604         {
605             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
606                           << ") is not part of the indexed set.\n"
607                           << "  Actual: It is returned.";
608         }
609         else if (foundRefPair->bFound)
610         {
611             ADD_FAILURE() << "Expected: Pair (ref: " << refIndex << ", test: " << testIndex
612                           << ") is returned only once.\n"
613                           << "  Actual: It is returned multiple times.";
614             return;
615         }
616         else
617         {
618             foundRefPair->bFound = true;
619
620             EXPECT_REAL_EQ_TOL(foundRefPair->distance, searchPair.distance, data.relativeTolerance())
621                     << "Distance computed by the neighborhood search does not match.";
622             if (selfPairs)
623             {
624                 searchPair              = NeighborhoodSearchTestData::RefPair(testIndex, 0.0);
625                 const auto otherRefPair = std::lower_bound(refPairs[refIndex].begin(),
626                                                            refPairs[refIndex].end(), searchPair);
627                 GMX_RELEASE_ASSERT(otherRefPair != refPairs[refIndex].end(),
628                                    "Precomputed reference data is not symmetric");
629                 otherRefPair->bFound = true;
630             }
631         }
632     }
633
634     for (auto& entry : refPairs)
635     {
636         const int testIndex = entry.first;
637         checkAllPairsFound(entry.second, data.refPos_, testIndex, data.testPositions_[testIndex].x);
638     }
639 }
640
641 /********************************************************************
642  * Test data generation
643  */
644
645 class TrivialTestData
646 {
647 public:
648     static const NeighborhoodSearchTestData& get()
649     {
650         static TrivialTestData singleton;
651         return singleton.data_;
652     }
653
654     TrivialTestData() : data_(12345, 1.0)
655     {
656         // Make the box so small we are virtually guaranteed to have
657         // several neighbors for the five test positions
658         data_.box_[XX][XX] = 3.0;
659         data_.box_[YY][YY] = 3.0;
660         data_.box_[ZZ][ZZ] = 3.0;
661         data_.generateRandomRefPositions(10);
662         data_.generateRandomTestPositions(5);
663         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
664         data_.computeReferences(&data_.pbc_);
665     }
666
667 private:
668     NeighborhoodSearchTestData data_;
669 };
670
671 class TrivialSelfPairsTestData
672 {
673 public:
674     static const NeighborhoodSearchTestData& get()
675     {
676         static TrivialSelfPairsTestData singleton;
677         return singleton.data_;
678     }
679
680     TrivialSelfPairsTestData() : data_(12345, 1.0)
681     {
682         data_.box_[XX][XX] = 3.0;
683         data_.box_[YY][YY] = 3.0;
684         data_.box_[ZZ][ZZ] = 3.0;
685         data_.generateRandomRefPositions(20);
686         data_.useRefPositionsAsTestPositions();
687         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
688         data_.computeReferences(&data_.pbc_);
689     }
690
691 private:
692     NeighborhoodSearchTestData data_;
693 };
694
695 class TrivialNoPBCTestData
696 {
697 public:
698     static const NeighborhoodSearchTestData& get()
699     {
700         static TrivialNoPBCTestData singleton;
701         return singleton.data_;
702     }
703
704     TrivialNoPBCTestData() : data_(12345, 1.0)
705     {
706         data_.generateRandomRefPositions(10);
707         data_.generateRandomTestPositions(5);
708         data_.computeReferences(nullptr);
709     }
710
711 private:
712     NeighborhoodSearchTestData data_;
713 };
714
715 class RandomBoxFullPBCData
716 {
717 public:
718     static const NeighborhoodSearchTestData& get()
719     {
720         static RandomBoxFullPBCData singleton;
721         return singleton.data_;
722     }
723
724     RandomBoxFullPBCData() : data_(12345, 1.0)
725     {
726         data_.box_[XX][XX] = 10.0;
727         data_.box_[YY][YY] = 5.0;
728         data_.box_[ZZ][ZZ] = 7.0;
729         // TODO: Consider whether manually picking some positions would give better
730         // test coverage.
731         data_.generateRandomRefPositions(1000);
732         data_.generateRandomTestPositions(100);
733         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
734         data_.computeReferences(&data_.pbc_);
735     }
736
737 private:
738     NeighborhoodSearchTestData data_;
739 };
740
741 class RandomBoxSelfPairsData
742 {
743 public:
744     static const NeighborhoodSearchTestData& get()
745     {
746         static RandomBoxSelfPairsData singleton;
747         return singleton.data_;
748     }
749
750     RandomBoxSelfPairsData() : data_(12345, 1.0)
751     {
752         data_.box_[XX][XX] = 10.0;
753         data_.box_[YY][YY] = 5.0;
754         data_.box_[ZZ][ZZ] = 7.0;
755         data_.generateRandomRefPositions(1000);
756         data_.useRefPositionsAsTestPositions();
757         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
758         data_.computeReferences(&data_.pbc_);
759     }
760
761 private:
762     NeighborhoodSearchTestData data_;
763 };
764
765 class RandomBoxXYFullPBCData
766 {
767 public:
768     static const NeighborhoodSearchTestData& get()
769     {
770         static RandomBoxXYFullPBCData singleton;
771         return singleton.data_;
772     }
773
774     RandomBoxXYFullPBCData() : data_(54321, 1.0)
775     {
776         data_.box_[XX][XX] = 10.0;
777         data_.box_[YY][YY] = 5.0;
778         data_.box_[ZZ][ZZ] = 7.0;
779         // TODO: Consider whether manually picking some positions would give better
780         // test coverage.
781         data_.generateRandomRefPositions(1000);
782         data_.generateRandomTestPositions(100);
783         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
784         data_.computeReferencesXY(&data_.pbc_);
785     }
786
787 private:
788     NeighborhoodSearchTestData data_;
789 };
790
791 class RandomTriclinicFullPBCData
792 {
793 public:
794     static const NeighborhoodSearchTestData& get()
795     {
796         static RandomTriclinicFullPBCData singleton;
797         return singleton.data_;
798     }
799
800     RandomTriclinicFullPBCData() : data_(12345, 1.0)
801     {
802         data_.box_[XX][XX] = 5.0;
803         data_.box_[YY][XX] = 2.5;
804         data_.box_[YY][YY] = 2.5 * std::sqrt(3.0);
805         data_.box_[ZZ][XX] = 2.5;
806         data_.box_[ZZ][YY] = 2.5 * std::sqrt(1.0 / 3.0);
807         data_.box_[ZZ][ZZ] = 5.0 * std::sqrt(2.0 / 3.0);
808         // TODO: Consider whether manually picking some positions would give better
809         // test coverage.
810         data_.generateRandomRefPositions(1000);
811         data_.generateRandomTestPositions(100);
812         set_pbc(&data_.pbc_, epbcXYZ, data_.box_);
813         data_.computeReferences(&data_.pbc_);
814     }
815
816 private:
817     NeighborhoodSearchTestData data_;
818 };
819
820 class RandomBox2DPBCData
821 {
822 public:
823     static const NeighborhoodSearchTestData& get()
824     {
825         static RandomBox2DPBCData singleton;
826         return singleton.data_;
827     }
828
829     RandomBox2DPBCData() : data_(12345, 1.0)
830     {
831         data_.box_[XX][XX] = 10.0;
832         data_.box_[YY][YY] = 7.0;
833         data_.box_[ZZ][ZZ] = 5.0;
834         // TODO: Consider whether manually picking some positions would give better
835         // test coverage.
836         data_.generateRandomRefPositions(1000);
837         data_.generateRandomTestPositions(100);
838         set_pbc(&data_.pbc_, epbcXY, data_.box_);
839         data_.computeReferences(&data_.pbc_);
840     }
841
842 private:
843     NeighborhoodSearchTestData data_;
844 };
845
846 class RandomBoxNoPBCData
847 {
848 public:
849     static const NeighborhoodSearchTestData& get()
850     {
851         static RandomBoxNoPBCData singleton;
852         return singleton.data_;
853     }
854
855     RandomBoxNoPBCData() : data_(12345, 1.0)
856     {
857         data_.box_[XX][XX] = 10.0;
858         data_.box_[YY][YY] = 5.0;
859         data_.box_[ZZ][ZZ] = 7.0;
860         // TODO: Consider whether manually picking some positions would give better
861         // test coverage.
862         data_.generateRandomRefPositions(1000);
863         data_.generateRandomTestPositions(100);
864         set_pbc(&data_.pbc_, epbcNONE, data_.box_);
865         data_.computeReferences(nullptr);
866     }
867
868 private:
869     NeighborhoodSearchTestData data_;
870 };
871
872 /********************************************************************
873  * Actual tests
874  */
875
876 TEST_F(NeighborhoodSearchTest, SimpleSearch)
877 {
878     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
879
880     nb_.setCutoff(data.cutoff_);
881     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
882     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
883     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
884
885     testIsWithin(&search, data);
886     testMinimumDistance(&search, data);
887     testNearestPoint(&search, data);
888     testPairSearch(&search, data);
889
890     search.reset();
891     testPairSearchIndexed(&nb_, data, 123);
892 }
893
894 TEST_F(NeighborhoodSearchTest, SimpleSearchXY)
895 {
896     const NeighborhoodSearchTestData& data = RandomBoxXYFullPBCData::get();
897
898     nb_.setCutoff(data.cutoff_);
899     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
900     nb_.setXYMode(true);
901     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
902     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
903
904     testIsWithin(&search, data);
905     testMinimumDistance(&search, data);
906     testNearestPoint(&search, data);
907     testPairSearch(&search, data);
908 }
909
910 TEST_F(NeighborhoodSearchTest, GridSearchBox)
911 {
912     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
913
914     nb_.setCutoff(data.cutoff_);
915     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
916     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
917     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
918
919     testIsWithin(&search, data);
920     testMinimumDistance(&search, data);
921     testNearestPoint(&search, data);
922     testPairSearch(&search, data);
923
924     search.reset();
925     testPairSearchIndexed(&nb_, data, 456);
926 }
927
928 TEST_F(NeighborhoodSearchTest, GridSearchTriclinic)
929 {
930     const NeighborhoodSearchTestData& data = RandomTriclinicFullPBCData::get();
931
932     nb_.setCutoff(data.cutoff_);
933     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
934     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
935     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
936
937     testPairSearch(&search, data);
938 }
939
940 TEST_F(NeighborhoodSearchTest, GridSearch2DPBC)
941 {
942     const NeighborhoodSearchTestData& data = RandomBox2DPBCData::get();
943
944     nb_.setCutoff(data.cutoff_);
945     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
946     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
947     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
948
949     testIsWithin(&search, data);
950     testMinimumDistance(&search, data);
951     testNearestPoint(&search, data);
952     testPairSearch(&search, data);
953 }
954
955 TEST_F(NeighborhoodSearchTest, GridSearchNoPBC)
956 {
957     const NeighborhoodSearchTestData& data = RandomBoxNoPBCData::get();
958
959     nb_.setCutoff(data.cutoff_);
960     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
961     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
962     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
963
964     testPairSearch(&search, data);
965 }
966
967 TEST_F(NeighborhoodSearchTest, GridSearchXYBox)
968 {
969     const NeighborhoodSearchTestData& data = RandomBoxXYFullPBCData::get();
970
971     nb_.setCutoff(data.cutoff_);
972     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
973     nb_.setXYMode(true);
974     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
975     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
976
977     testIsWithin(&search, data);
978     testMinimumDistance(&search, data);
979     testNearestPoint(&search, data);
980     testPairSearch(&search, data);
981 }
982
983 TEST_F(NeighborhoodSearchTest, SimpleSelfPairsSearch)
984 {
985     const NeighborhoodSearchTestData& data = TrivialSelfPairsTestData::get();
986
987     nb_.setCutoff(data.cutoff_);
988     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
989     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
990     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
991
992     testPairSearchFull(&search, data, data.testPositions(), nullptr, {}, {}, true);
993 }
994
995 TEST_F(NeighborhoodSearchTest, GridSelfPairsSearch)
996 {
997     const NeighborhoodSearchTestData& data = RandomBoxSelfPairsData::get();
998
999     nb_.setCutoff(data.cutoff_);
1000     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
1001     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
1002     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
1003
1004     testPairSearchFull(&search, data, data.testPositions(), nullptr, {}, {}, true);
1005 }
1006
1007 TEST_F(NeighborhoodSearchTest, HandlesConcurrentSearches)
1008 {
1009     const NeighborhoodSearchTestData& data = TrivialTestData::get();
1010
1011     nb_.setCutoff(data.cutoff_);
1012     gmx::AnalysisNeighborhoodSearch search1 = nb_.initSearch(&data.pbc_, data.refPositions());
1013     gmx::AnalysisNeighborhoodSearch search2 = nb_.initSearch(&data.pbc_, data.refPositions());
1014
1015     // These checks are fragile, and unfortunately depend on the random
1016     // engine used to create the test positions. There is no particular reason
1017     // why exactly particles 0 & 2 should have neighbors, but in this case they do.
1018     gmx::AnalysisNeighborhoodPairSearch pairSearch1 = search1.startPairSearch(data.testPosition(0));
1019     gmx::AnalysisNeighborhoodPairSearch pairSearch2 = search1.startPairSearch(data.testPosition(2));
1020
1021     testPairSearch(&search2, data);
1022
1023     gmx::AnalysisNeighborhoodPair pair;
1024     ASSERT_TRUE(pairSearch1.findNextPair(&pair))
1025             << "Test data did not contain any pairs for position 0 (problem in the test).";
1026     EXPECT_EQ(0, pair.testIndex());
1027     {
1028         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1029         EXPECT_TRUE(data.containsPair(0, searchPair));
1030     }
1031
1032     ASSERT_TRUE(pairSearch2.findNextPair(&pair))
1033             << "Test data did not contain any pairs for position 2 (problem in the test).";
1034     EXPECT_EQ(2, pair.testIndex());
1035     {
1036         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1037         EXPECT_TRUE(data.containsPair(2, searchPair));
1038     }
1039 }
1040
1041 TEST_F(NeighborhoodSearchTest, HandlesNoPBC)
1042 {
1043     const NeighborhoodSearchTestData& data = TrivialNoPBCTestData::get();
1044
1045     nb_.setCutoff(data.cutoff_);
1046     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(&data.pbc_, data.refPositions());
1047     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1048
1049     testIsWithin(&search, data);
1050     testMinimumDistance(&search, data);
1051     testNearestPoint(&search, data);
1052     testPairSearch(&search, data);
1053 }
1054
1055 TEST_F(NeighborhoodSearchTest, HandlesNullPBC)
1056 {
1057     const NeighborhoodSearchTestData& data = TrivialNoPBCTestData::get();
1058
1059     nb_.setCutoff(data.cutoff_);
1060     gmx::AnalysisNeighborhoodSearch search = nb_.initSearch(nullptr, data.refPositions());
1061     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1062
1063     testIsWithin(&search, data);
1064     testMinimumDistance(&search, data);
1065     testNearestPoint(&search, data);
1066     testPairSearch(&search, data);
1067 }
1068
1069 TEST_F(NeighborhoodSearchTest, HandlesSkippingPairs)
1070 {
1071     const NeighborhoodSearchTestData& data = TrivialTestData::get();
1072
1073     nb_.setCutoff(data.cutoff_);
1074     gmx::AnalysisNeighborhoodSearch     search = nb_.initSearch(&data.pbc_, data.refPositions());
1075     gmx::AnalysisNeighborhoodPairSearch pairSearch = search.startPairSearch(data.testPositions());
1076     gmx::AnalysisNeighborhoodPair       pair;
1077     // TODO: This test needs to be adjusted if the grid search gets optimized
1078     // to loop over the test positions in cell order (first, the ordering
1079     // assumption here breaks, and second, it then needs to be tested
1080     // separately for simple and grid searches).
1081     int currentIndex = 0;
1082     while (pairSearch.findNextPair(&pair))
1083     {
1084         while (currentIndex < pair.testIndex())
1085         {
1086             ++currentIndex;
1087         }
1088         EXPECT_EQ(currentIndex, pair.testIndex());
1089         NeighborhoodSearchTestData::RefPair searchPair(pair.refIndex(), std::sqrt(pair.distance2()));
1090         EXPECT_TRUE(data.containsPair(currentIndex, searchPair));
1091         pairSearch.skipRemainingPairsForTestPosition();
1092         ++currentIndex;
1093     }
1094 }
1095
1096 TEST_F(NeighborhoodSearchTest, SimpleSearchExclusions)
1097 {
1098     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
1099
1100     ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
1101     helper.generateExclusions();
1102
1103     nb_.setCutoff(data.cutoff_);
1104     nb_.setTopologyExclusions(helper.exclusions());
1105     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Simple);
1106     gmx::AnalysisNeighborhoodSearch search =
1107             nb_.initSearch(&data.pbc_, data.refPositions().exclusionIds(helper.refPosIds()));
1108     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Simple, search.mode());
1109
1110     testPairSearchFull(&search, data, data.testPositions().exclusionIds(helper.testPosIds()),
1111                        helper.exclusions(), {}, {}, false);
1112 }
1113
1114 TEST_F(NeighborhoodSearchTest, GridSearchExclusions)
1115 {
1116     const NeighborhoodSearchTestData& data = RandomBoxFullPBCData::get();
1117
1118     ExclusionsHelper helper(data.refPosCount_, data.testPositions_.size());
1119     helper.generateExclusions();
1120
1121     nb_.setCutoff(data.cutoff_);
1122     nb_.setTopologyExclusions(helper.exclusions());
1123     nb_.setMode(gmx::AnalysisNeighborhood::eSearchMode_Grid);
1124     gmx::AnalysisNeighborhoodSearch search =
1125             nb_.initSearch(&data.pbc_, data.refPositions().exclusionIds(helper.refPosIds()));
1126     ASSERT_EQ(gmx::AnalysisNeighborhood::eSearchMode_Grid, search.mode());
1127
1128     testPairSearchFull(&search, data, data.testPositions().exclusionIds(helper.testPosIds()),
1129                        helper.exclusions(), {}, {}, false);
1130 }
1131
1132 } // namespace