Precalculate pbc shift for analysis nbsearch
authorTeemu Murtola <teemu.murtola@gmail.com>
Sat, 23 Aug 2014 03:48:55 +0000 (06:48 +0300)
committerRoland Schulz <roland@rschulz.eu>
Sun, 24 Aug 2014 04:56:04 +0000 (06:56 +0200)
Instead of using pbc_dx_aiuc(), precalculate the PBC shift between grid
cells outside the inner loop when doing grid searching for analysis
neighborhood searching.  In addition to improving the performance, this
encapsulates another piece of code that needs to be changed to implement
more generic grids.

Change-Id: Ifbbe54596f820b01572fe7bb97a5354556a4981d

src/gromacs/selection/nbsearch.cpp
src/gromacs/selection/tests/nbsearch.cpp

index 0b50214a4dd5e234b7a160fe254c994e31a4e5f8..d626b66d4e98534da51c71cd21af6a6563a85741 100644 (file)
@@ -40,9 +40,6 @@
  * The grid implementation could still be optimized in several different ways:
  *   - Triclinic grid cells are not the most efficient shape, but make PBC
  *     handling easier.
- *   - Precalculating the required PBC shift for a pair of cells outside the
- *     inner loop. After this is done, it should be quite straightforward to
- *     move to rectangular cells.
  *   - Pruning grid cells from the search list if they are completely outside
  *     the sphere that is being considered.
  *   - A better heuristic could be added for falling back to simple loops for a
@@ -159,6 +156,16 @@ class AnalysisNeighborhoodSearchImpl
          * \param[in]     i    Index to add.
          */
         void addToGridCell(const ivec cell, int i);
+        /*! \brief
+         * Calculates the index of a neighboring grid cell.
+         *
+         * \param[in]  sourceCell Location of the source cell.
+         * \param[in]  index      Index of the neighbor to calculate.
+         * \param[out] shift      Shift to apply to get the periodic distance
+         *     for distances between the cells.
+         * \returns    Grid cell index of the neighboring cell.
+         */
+        int getNeighboringCell(const ivec sourceCell, int index, rvec shift) const;
 
         //! Whether to try grid searching.
         bool                    bTryGrid_;
@@ -507,6 +514,33 @@ void AnalysisNeighborhoodSearchImpl::addToGridCell(const ivec cell, int i)
     cells_[ci].push_back(i);
 }
 
+int AnalysisNeighborhoodSearchImpl::getNeighboringCell(
+        const ivec sourceCell, int index, rvec shift) const
+{
+    ivec cell;
+    ivec_add(sourceCell, gnboffs_[index], cell);
+
+    // TODO: Consider unifying with the similar shifting code in
+    // mapPointToGridCell().
+    clear_rvec(shift);
+    for (int d = 0; d < DIM; ++d)
+    {
+        const int cellCount = ncelldim_[d];
+        if (cell[d] < 0)
+        {
+            cell[d] += cellCount;
+            rvec_add(shift, pbc_.box[d], shift);
+        }
+        else if (cell[d] >= cellCount)
+        {
+            cell[d] -= cellCount;
+            rvec_sub(shift, pbc_.box[d], shift);
+        }
+    }
+
+    return getGridCellIndex(cell);
+}
+
 void AnalysisNeighborhoodSearchImpl::init(
         AnalysisNeighborhood::SearchMode     mode,
         bool                                 bXY,
@@ -688,16 +722,9 @@ bool AnalysisNeighborhoodPairSearchImpl::searchNext(Action action)
 
             for (; nbi < search_.ngridnb_; ++nbi)
             {
-                ivec cell;
-
-                ivec_add(testcell_, search_.gnboffs_[nbi], cell);
-                cell[XX] = (cell[XX] + search_.ncelldim_[XX]) % search_.ncelldim_[XX];
-                cell[YY] = (cell[YY] + search_.ncelldim_[YY]) % search_.ncelldim_[YY];
-                cell[ZZ] = (cell[ZZ] + search_.ncelldim_[ZZ]) % search_.ncelldim_[ZZ];
-
-                const int ci       = search_.getGridCellIndex(cell);
+                rvec      shift;
+                const int ci       = search_.getNeighboringCell(testcell_, nbi, shift);
                 const int cellSize = static_cast<int>(search_.cells_[ci].size());
-                /* TODO: Calculate the required PBC shift outside the inner loop */
                 for (; cai < cellSize; ++cai)
                 {
                     const int i = search_.cells_[ci][cai];
@@ -706,7 +733,8 @@ bool AnalysisNeighborhoodPairSearchImpl::searchNext(Action action)
                         continue;
                     }
                     rvec       dx;
-                    pbc_dx_aiuc(&search_.pbc_, xtest_, search_.xref_[i], dx);
+                    rvec_sub(xtest_, search_.xref_[i], dx);
+                    rvec_add(dx, shift, dx);
                     const real r2 = norm2(dx);
                     if (r2 <= search_.cutoff2_)
                     {
index 1d89772d2a197883fb94864b6cec07c1390821da..ac7cba62a518b9b0cb901c33844d53faec7299d8 100644 (file)
@@ -62,6 +62,7 @@
 #include "gromacs/random/random.h"
 #include "gromacs/topology/block.h"
 #include "gromacs/utility/smalloc.h"
+#include "gromacs/utility/stringutil.h"
 
 #include "testutils/testasserts.h"
 
@@ -454,22 +455,40 @@ void NeighborhoodSearchTest::testNearestPoint(
     }
 }
 
+//! Helper function for formatting test failure messages.
+std::string formatVector(const rvec x)
+{
+    return gmx::formatString("[%.3f, %.3f, %.3f]", x[XX], x[YY], x[ZZ]);
+}
+
 /*! \brief
  * Helper function to check that all expected pairs were found.
  */
-static void checkAllPairsFound(const RefPairList &refPairs)
+void checkAllPairsFound(const RefPairList &refPairs, const rvec refPos[],
+                        int testPosIndex, const rvec testPos)
 {
     // This could be elegantly expressed with Google Mock matchers, but that
     // has a significant effect on the runtime of the tests...
+    int                         count = 0;
+    RefPairList::const_iterator first;
     for (RefPairList::const_iterator i = refPairs.begin(); i != refPairs.end(); ++i)
     {
         if (!i->bFound)
         {
-            ADD_FAILURE()
-            << "Some pairs within the cutoff were not found.";
-            break;
+            ++count;
+            first = i;
         }
     }
+    if (count > 0)
+    {
+        ADD_FAILURE()
+        << "Some pairs (" << count << "/" << refPairs.size() << ") "
+        << "within the cutoff were not found. First pair:\n"
+        << " Ref: " << first->refIndex << " at "
+        << formatVector(refPos[first->refIndex]) << "\n"
+        << "Test: " << testPosIndex << " at " << formatVector(testPos) << "\n"
+        << "Dist: " << first->distance;
+    }
 }
 
 void NeighborhoodSearchTest::testPairSearch(
@@ -505,7 +524,8 @@ void NeighborhoodSearchTest::testPairSearchFull(
         {
             if (prevTestPos != -1)
             {
-                checkAllPairsFound(refPairs);
+                checkAllPairsFound(refPairs, data.refPos_, prevTestPos,
+                                   data.testPositions_[prevTestPos].x);
             }
             const int testIndex = pair.testIndex();
             if (remainingTestPositions.count(testIndex) == 0)
@@ -556,7 +576,8 @@ void NeighborhoodSearchTest::testPairSearchFull(
             << "Distance computed by the neighborhood search does not match.";
         }
     }
-    checkAllPairsFound(refPairs);
+    checkAllPairsFound(refPairs, data.refPos_, prevTestPos,
+                       data.testPositions_[prevTestPos].x);
     for (std::set<int>::const_iterator i = remainingTestPositions.begin();
          i != remainingTestPositions.end(); ++i)
     {