From ee7638d8e6a8d5e35a8264d120e3434afa8405b8 Mon Sep 17 00:00:00 2001 From: Teemu Murtola Date: Sat, 23 Aug 2014 06:48:55 +0300 Subject: [PATCH] Precalculate pbc shift for analysis nbsearch Instead of using pbc_dx_aiuc(), precalculate the PBC shift between grid cells outside the inner loop when doing grid searching for analysis neighborhood searching. In addition to improving the performance, this encapsulates another piece of code that needs to be changed to implement more generic grids. Change-Id: Ifbbe54596f820b01572fe7bb97a5354556a4981d --- src/gromacs/selection/nbsearch.cpp | 54 ++++++++++++++++++------ src/gromacs/selection/tests/nbsearch.cpp | 33 ++++++++++++--- 2 files changed, 68 insertions(+), 19 deletions(-) diff --git a/src/gromacs/selection/nbsearch.cpp b/src/gromacs/selection/nbsearch.cpp index 0b50214a4d..d626b66d4e 100644 --- a/src/gromacs/selection/nbsearch.cpp +++ b/src/gromacs/selection/nbsearch.cpp @@ -40,9 +40,6 @@ * The grid implementation could still be optimized in several different ways: * - Triclinic grid cells are not the most efficient shape, but make PBC * handling easier. - * - Precalculating the required PBC shift for a pair of cells outside the - * inner loop. After this is done, it should be quite straightforward to - * move to rectangular cells. * - Pruning grid cells from the search list if they are completely outside * the sphere that is being considered. * - A better heuristic could be added for falling back to simple loops for a @@ -159,6 +156,16 @@ class AnalysisNeighborhoodSearchImpl * \param[in] i Index to add. */ void addToGridCell(const ivec cell, int i); + /*! \brief + * Calculates the index of a neighboring grid cell. + * + * \param[in] sourceCell Location of the source cell. + * \param[in] index Index of the neighbor to calculate. + * \param[out] shift Shift to apply to get the periodic distance + * for distances between the cells. + * \returns Grid cell index of the neighboring cell. + */ + int getNeighboringCell(const ivec sourceCell, int index, rvec shift) const; //! Whether to try grid searching. bool bTryGrid_; @@ -507,6 +514,33 @@ void AnalysisNeighborhoodSearchImpl::addToGridCell(const ivec cell, int i) cells_[ci].push_back(i); } +int AnalysisNeighborhoodSearchImpl::getNeighboringCell( + const ivec sourceCell, int index, rvec shift) const +{ + ivec cell; + ivec_add(sourceCell, gnboffs_[index], cell); + + // TODO: Consider unifying with the similar shifting code in + // mapPointToGridCell(). + clear_rvec(shift); + for (int d = 0; d < DIM; ++d) + { + const int cellCount = ncelldim_[d]; + if (cell[d] < 0) + { + cell[d] += cellCount; + rvec_add(shift, pbc_.box[d], shift); + } + else if (cell[d] >= cellCount) + { + cell[d] -= cellCount; + rvec_sub(shift, pbc_.box[d], shift); + } + } + + return getGridCellIndex(cell); +} + void AnalysisNeighborhoodSearchImpl::init( AnalysisNeighborhood::SearchMode mode, bool bXY, @@ -688,16 +722,9 @@ bool AnalysisNeighborhoodPairSearchImpl::searchNext(Action action) for (; nbi < search_.ngridnb_; ++nbi) { - ivec cell; - - ivec_add(testcell_, search_.gnboffs_[nbi], cell); - cell[XX] = (cell[XX] + search_.ncelldim_[XX]) % search_.ncelldim_[XX]; - cell[YY] = (cell[YY] + search_.ncelldim_[YY]) % search_.ncelldim_[YY]; - cell[ZZ] = (cell[ZZ] + search_.ncelldim_[ZZ]) % search_.ncelldim_[ZZ]; - - const int ci = search_.getGridCellIndex(cell); + rvec shift; + const int ci = search_.getNeighboringCell(testcell_, nbi, shift); const int cellSize = static_cast(search_.cells_[ci].size()); - /* TODO: Calculate the required PBC shift outside the inner loop */ for (; cai < cellSize; ++cai) { const int i = search_.cells_[ci][cai]; @@ -706,7 +733,8 @@ bool AnalysisNeighborhoodPairSearchImpl::searchNext(Action action) continue; } rvec dx; - pbc_dx_aiuc(&search_.pbc_, xtest_, search_.xref_[i], dx); + rvec_sub(xtest_, search_.xref_[i], dx); + rvec_add(dx, shift, dx); const real r2 = norm2(dx); if (r2 <= search_.cutoff2_) { diff --git a/src/gromacs/selection/tests/nbsearch.cpp b/src/gromacs/selection/tests/nbsearch.cpp index 1d89772d2a..ac7cba62a5 100644 --- a/src/gromacs/selection/tests/nbsearch.cpp +++ b/src/gromacs/selection/tests/nbsearch.cpp @@ -62,6 +62,7 @@ #include "gromacs/random/random.h" #include "gromacs/topology/block.h" #include "gromacs/utility/smalloc.h" +#include "gromacs/utility/stringutil.h" #include "testutils/testasserts.h" @@ -454,22 +455,40 @@ void NeighborhoodSearchTest::testNearestPoint( } } +//! Helper function for formatting test failure messages. +std::string formatVector(const rvec x) +{ + return gmx::formatString("[%.3f, %.3f, %.3f]", x[XX], x[YY], x[ZZ]); +} + /*! \brief * Helper function to check that all expected pairs were found. */ -static void checkAllPairsFound(const RefPairList &refPairs) +void checkAllPairsFound(const RefPairList &refPairs, const rvec refPos[], + int testPosIndex, const rvec testPos) { // This could be elegantly expressed with Google Mock matchers, but that // has a significant effect on the runtime of the tests... + int count = 0; + RefPairList::const_iterator first; for (RefPairList::const_iterator i = refPairs.begin(); i != refPairs.end(); ++i) { if (!i->bFound) { - ADD_FAILURE() - << "Some pairs within the cutoff were not found."; - break; + ++count; + first = i; } } + if (count > 0) + { + ADD_FAILURE() + << "Some pairs (" << count << "/" << refPairs.size() << ") " + << "within the cutoff were not found. First pair:\n" + << " Ref: " << first->refIndex << " at " + << formatVector(refPos[first->refIndex]) << "\n" + << "Test: " << testPosIndex << " at " << formatVector(testPos) << "\n" + << "Dist: " << first->distance; + } } void NeighborhoodSearchTest::testPairSearch( @@ -505,7 +524,8 @@ void NeighborhoodSearchTest::testPairSearchFull( { if (prevTestPos != -1) { - checkAllPairsFound(refPairs); + checkAllPairsFound(refPairs, data.refPos_, prevTestPos, + data.testPositions_[prevTestPos].x); } const int testIndex = pair.testIndex(); if (remainingTestPositions.count(testIndex) == 0) @@ -556,7 +576,8 @@ void NeighborhoodSearchTest::testPairSearchFull( << "Distance computed by the neighborhood search does not match."; } } - checkAllPairsFound(refPairs); + checkAllPairsFound(refPairs, data.refPos_, prevTestPos, + data.testPositions_[prevTestPos].x); for (std::set::const_iterator i = remainingTestPositions.begin(); i != remainingTestPositions.end(); ++i) { -- 2.22.0