Sort all includes in src/gromacs
[alexxy/gromacs.git] / src / gromacs / selection / nbsearch.cpp
index 1e8f1cfdadfe46617dab5a475ca22e327ef95578..7e32fa6536b1ad066a2d7cdb1bb9df420a9fc742 100644 (file)
@@ -40,9 +40,6 @@
  * The grid implementation could still be optimized in several different ways:
  *   - Triclinic grid cells are not the most efficient shape, but make PBC
  *     handling easier.
- *   - Precalculating the required PBC shift for a pair of cells outside the
- *     inner loop. After this is done, it should be quite straightforward to
- *     move to rectangular cells.
  *   - Pruning grid cells from the search list if they are completely outside
  *     the sphere that is being considered.
  *   - A better heuristic could be added for falling back to simple loops for a
@@ -55,7 +52,9 @@
  * \author Teemu Murtola <teemu.murtola@gmail.com>
  * \ingroup module_selection
  */
-#include "gromacs/selection/nbsearch.h"
+#include "gmxpre.h"
+
+#include "nbsearch.h"
 
 #include <cmath>
 #include <cstring>
@@ -66,7 +65,6 @@
 #include "thread_mpi/mutex.h"
 
 #include "gromacs/legacyheaders/names.h"
-
 #include "gromacs/math/vec.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/selection/position.h"
@@ -139,10 +137,10 @@ class AnalysisNeighborhoodSearchImpl
          *
          * \param[in]  x    Point to map.
          * \param[out] cell Indices of the grid cell in which \p x lies.
-         *
-         * \p x should be within the triclinic unit cell.
+         * \param[out] xout Coordinates to use
+         *     (will be within the triclinic unit cell).
          */
-        void mapPointToGridCell(const rvec x, ivec cell) const;
+        void mapPointToGridCell(const rvec x, ivec cell, rvec xout) const;
         /*! \brief
          * Calculates linear index of a grid cell.
          *
@@ -157,6 +155,16 @@ class AnalysisNeighborhoodSearchImpl
          * \param[in]     i    Index to add.
          */
         void addToGridCell(const ivec cell, int i);
+        /*! \brief
+         * Calculates the index of a neighboring grid cell.
+         *
+         * \param[in]  sourceCell Location of the source cell.
+         * \param[in]  index      Index of the neighbor to calculate.
+         * \param[out] shift      Shift to apply to get the periodic distance
+         *     for distances between the cells.
+         * \returns    Grid cell index of the neighboring cell.
+         */
+        int getNeighboringCell(const ivec sourceCell, int index, rvec shift) const;
 
         //! Whether to try grid searching.
         bool                    bTryGrid_;
@@ -439,25 +447,52 @@ bool AnalysisNeighborhoodSearchImpl::initGrid(const t_pbc &pbc)
 }
 
 void AnalysisNeighborhoodSearchImpl::mapPointToGridCell(const rvec x,
-                                                        ivec       cell) const
+                                                        ivec       cell,
+                                                        rvec       xout) const
 {
+    rvec xtmp;
+    copy_rvec(x, xtmp);
     if (bTric_)
     {
         rvec tx;
-
-        tmvmul_ur0(recipcell_, x, tx);
+        tmvmul_ur0(recipcell_, xtmp, tx);
         for (int dd = 0; dd < DIM; ++dd)
         {
-            cell[dd] = static_cast<int>(tx[dd]);
+            const int cellCount = ncelldim_[dd];
+            int       cellIndex = static_cast<int>(floor(tx[dd]));
+            while (cellIndex < 0)
+            {
+                cellIndex += cellCount;
+                rvec_add(xtmp, pbc_.box[dd], xtmp);
+            }
+            while (cellIndex >= cellCount)
+            {
+                cellIndex -= cellCount;
+                rvec_sub(xtmp, pbc_.box[dd], xtmp);
+            }
+            cell[dd] = cellIndex;
         }
     }
     else
     {
         for (int dd = 0; dd < DIM; ++dd)
         {
-            cell[dd] = static_cast<int>(x[dd] * recipcell_[dd][dd]);
+            const int cellCount = ncelldim_[dd];
+            int       cellIndex = static_cast<int>(floor(xtmp[dd] * recipcell_[dd][dd]));
+            while (cellIndex < 0)
+            {
+                cellIndex += cellCount;
+                xtmp[dd]  += pbc_.box[dd][dd];
+            }
+            while (cellIndex >= cellCount)
+            {
+                cellIndex -= cellCount;
+                xtmp[dd]  -= pbc_.box[dd][dd];
+            }
+            cell[dd] = cellIndex;
         }
     }
+    copy_rvec(xtmp, xout);
 }
 
 int AnalysisNeighborhoodSearchImpl::getGridCellIndex(const ivec cell) const
@@ -478,6 +513,33 @@ void AnalysisNeighborhoodSearchImpl::addToGridCell(const ivec cell, int i)
     cells_[ci].push_back(i);
 }
 
+int AnalysisNeighborhoodSearchImpl::getNeighboringCell(
+        const ivec sourceCell, int index, rvec shift) const
+{
+    ivec cell;
+    ivec_add(sourceCell, gnboffs_[index], cell);
+
+    // TODO: Consider unifying with the similar shifting code in
+    // mapPointToGridCell().
+    clear_rvec(shift);
+    for (int d = 0; d < DIM; ++d)
+    {
+        const int cellCount = ncelldim_[d];
+        if (cell[d] < 0)
+        {
+            cell[d] += cellCount;
+            rvec_add(shift, pbc_.box[d], shift);
+        }
+        else if (cell[d] >= cellCount)
+        {
+            cell[d] -= cellCount;
+            rvec_sub(shift, pbc_.box[d], shift);
+        }
+    }
+
+    return getGridCellIndex(cell);
+}
+
 void AnalysisNeighborhoodSearchImpl::init(
         AnalysisNeighborhood::SearchMode     mode,
         bool                                 bXY,
@@ -534,17 +596,10 @@ void AnalysisNeighborhoodSearchImpl::init(
         }
         xref_ = xref_alloc_;
 
-        for (int i = 0; i < nref_; ++i)
-        {
-            copy_rvec(positions.x_[i], xref_alloc_[i]);
-        }
-        put_atoms_in_triclinic_unitcell(ecenterTRIC, pbc_.box,
-                                        nref_, xref_alloc_);
         for (int i = 0; i < nref_; ++i)
         {
             ivec refcell;
-
-            mapPointToGridCell(xref_[i], refcell);
+            mapPointToGridCell(positions.x_[i], refcell, xref_alloc_[i]);
             addToGridCell(refcell, i);
         }
     }
@@ -576,10 +631,11 @@ void AnalysisNeighborhoodPairSearchImpl::reset(int testIndex)
         copy_rvec(testPositions_[testIndex_], xtest_);
         if (search_.bGrid_)
         {
-            put_atoms_in_triclinic_unitcell(ecenterTRIC,
-                                            const_cast<rvec *>(search_.pbc_.box),
-                                            1, &xtest_);
-            search_.mapPointToGridCell(xtest_, testcell_);
+            search_.mapPointToGridCell(testPositions_[testIndex], testcell_, xtest_);
+        }
+        else
+        {
+            copy_rvec(testPositions_[testIndex_], xtest_);
         }
         if (search_.excls_ != NULL)
         {
@@ -665,16 +721,9 @@ bool AnalysisNeighborhoodPairSearchImpl::searchNext(Action action)
 
             for (; nbi < search_.ngridnb_; ++nbi)
             {
-                ivec cell;
-
-                ivec_add(testcell_, search_.gnboffs_[nbi], cell);
-                cell[XX] = (cell[XX] + search_.ncelldim_[XX]) % search_.ncelldim_[XX];
-                cell[YY] = (cell[YY] + search_.ncelldim_[YY]) % search_.ncelldim_[YY];
-                cell[ZZ] = (cell[ZZ] + search_.ncelldim_[ZZ]) % search_.ncelldim_[ZZ];
-
-                const int ci       = search_.getGridCellIndex(cell);
+                rvec      shift;
+                const int ci       = search_.getNeighboringCell(testcell_, nbi, shift);
                 const int cellSize = static_cast<int>(search_.cells_[ci].size());
-                /* TODO: Calculate the required PBC shift outside the inner loop */
                 for (; cai < cellSize; ++cai)
                 {
                     const int i = search_.cells_[ci][cai];
@@ -683,7 +732,8 @@ bool AnalysisNeighborhoodPairSearchImpl::searchNext(Action action)
                         continue;
                     }
                     rvec       dx;
-                    pbc_dx_aiuc(&search_.pbc_, xtest_, search_.xref_[i], dx);
+                    rvec_sub(xtest_, search_.xref_[i], dx);
+                    rvec_add(dx, shift, dx);
                     const real r2 = norm2(dx);
                     if (r2 <= search_.cutoff2_)
                     {