Add size optimization to HashedMap
authorBerk Hess <hess@kth.se>
Fri, 10 Aug 2018 11:16:27 +0000 (13:16 +0200)
committerBerk Hess <hess@kth.se>
Fri, 24 Aug 2018 10:35:59 +0000 (12:35 +0200)
The table size in HashedMap is now optimized when calling clear()
using the old number of keys. Also the number of keys is now set
to a power of 2, so we can use bit masking instead of modulo.
The bit masking allows for negative keys, which is also tested.

This is preparation for replacing gmx_hash_t with HashedMap,
but also improves performance for gmx_ga2la_t.

Change-Id: I90c5a602cb7e213eb6d2e8259a0effc4fd7c4e14

src/gromacs/domdec/ga2la.cpp
src/gromacs/domdec/hashedmap.h
src/gromacs/domdec/tests/hashedmap.cpp

index 3ab76677608736df044b9420280254e6ce0d16e1..071e9697e48fb76cbec532e4fcec176b1ffddf10 100644 (file)
@@ -69,20 +69,6 @@ static bool directListIsFaster(int numAtomsTotal,
             numAtomsTotal <= numAtomsLocal*c_memoryRatioHashedVersusDirect);
 }
 
-/*! \brief Returns the base size of the hash table
- *
- * Make the direct list twice as long as the number of local atoms.
- * The fraction of entries in the list with:
- * 0   size lists: e^-1/f
- * >=1 size lists: 1 - e^-1/f
- * where f is: the direct list length / nr. of local atoms
- * The fraction of atoms not in the direct list is: 1-f(1-e^-1/f).
- */
-static int baseTableSizeForHashTable(int numAtomsLocal)
-{
-    return 2*numAtomsLocal;
-}
-
 gmx_ga2la_t::gmx_ga2la_t(int numAtomsTotal,
                          int numAtomsLocal) :
     usingDirect_(directListIsFaster(numAtomsTotal, numAtomsLocal))
@@ -93,6 +79,6 @@ gmx_ga2la_t::gmx_ga2la_t(int numAtomsTotal,
     }
     else
     {
-        new(&(data_.hashed)) gmx::HashedMap<Entry>(baseTableSizeForHashTable(numAtomsLocal));
+        new(&(data_.hashed)) gmx::HashedMap<Entry>(numAtomsLocal);
     }
 }
index 90b8c4cd8abb2c5d4470cb4c35c7061ee8601cc3..8370361a0e49668111430a24f6cf20d5cf4778c1 100644 (file)
@@ -47,6 +47,8 @@
 #ifndef GMX_DOMDEC_HASHEDMAP_H
 #define GMX_DOMDEC_HASHEDMAP_H
 
+#include <climits>
+
 #include <algorithm>
 #include <vector>
 
@@ -59,7 +61,7 @@ namespace gmx
 
 /*! \libinternal \brief Unordered key to value mapping
  *
- * Efficiently manages mapping from non-negative integer keys to values.
+ * Efficiently manages mapping from integer keys to values.
  * Note that this basically implements a subset of the functionality of
  * std::unordered_map, but is an order of magnitude faster.
  */
@@ -75,16 +77,65 @@ class HashedMap
             int  next = -1;  /**< Index in the list of the next element with the same hash, -1 if none */
         };
 
+        /*! \brief The table size is set to at least this factor time the nr of keys */
+        static constexpr float c_relTableSizeSetMin       = 1.5;
+        /*! \brief Threshold for increasing the table size */
+        static constexpr float c_relTableSizeThresholdMin = 1.3;
+        /*! \brief Threshold for decreasing the table size */
+        static constexpr float c_relTableSizeThresholdMax = 3.5;
+
+        /*! \brief Resizes the table
+         *
+         * \param[in] numElementsEstimate  An estimate of the number of elements that will be stored
+         */
+        void resize(int numElementsEstimate)
+        {
+            GMX_RELEASE_ASSERT(numElements_ == 0, "Table needs to be empty for resize");
+
+            /* The fraction of table entries with 0   size lists is e^-f.
+             * The fraction of table entries with >=1 size lists is 1 - e^-f
+             * where f is: the #elements / tableSize
+             * The fraction of elements not in the direct list is: 1 - (1 - e^-f)/f.
+             * Thus the optimal table size is roughly double #elements.
+             */
+            /* Make the hash table a power of 2 and at least 1.5 * #elements */
+            int tableSize = 64;
+            while (tableSize <= INT_MAX/2 &&
+                   numElementsEstimate*c_relTableSizeSetMin > tableSize)
+            {
+                tableSize *= 2;
+            }
+            table_.resize(tableSize);
+
+            /* Table size is a power of 2, so a binary mask gives the hash */
+            bitMask_                        = tableSize - 1;
+            startIndexForSpaceForListEntry_ = tableSize;
+        }
+
     public:
         /*! \brief Constructor
          *
-         * \param[in] baseTableSize  The size of the base table, optimal is around twice the number of expected entries
+         * \param[in] numElementsEstimate  An estimate of the number of elements that will be stored, used for optimizing initial performance
+         *
+         * Note that the estimate of the number of elements is only relevant
+         * for the performance up until the first call to clear(), after which
+         * table size is optimized based on the actual number of elements.
          */
-        HashedMap(int baseTableSize) :
-            table_(baseTableSize),
-            mod_(baseTableSize),
-            startSpaceSearch_(baseTableSize)
+        HashedMap(int numElementsEstimate)
         {
+            resize(numElementsEstimate);
+        }
+
+        /*! \brief Returns the number of elements */
+        int size() const
+        {
+            return numElements_;
+        }
+
+        /*! \brief Returns the number of buckets, i.e. the number of possible hashes */
+        int bucket_count() const
+        {
+            return bitMask_ + 1;
         }
 
         /*! \brief Inserts entry, key should not already be present
@@ -96,15 +147,7 @@ class HashedMap
         void insert(int      key,
                     const T &value)
         {
-            // Note: This is performance critical, so we only throw in debug mode
-#ifndef NDEBUG
-            if (key < 0)
-            {
-                GMX_THROW(InvalidInputError("Invalid key value"));
-            }
-#endif
-
-            size_t ind = key % mod_;
+            size_t ind = (key & bitMask_);
 
             if (table_[ind].key >= 0)
             {
@@ -112,6 +155,7 @@ class HashedMap
                  * If we find the matching key, return the value.
                  */
                 int ind_prev = ind;
+// Note: This is performance critical, so we only throw in debug mode
 #ifndef NDEBUG
                 if (table_[ind_prev].key == key)
                 {
@@ -128,8 +172,8 @@ class HashedMap
                     }
 #endif
                 }
-                /* Search for space in the array */
-                ind = startSpaceSearch_;
+                /* Search for space in table_ */
+                ind = startIndexForSpaceForListEntry_;
                 while (ind < table_.size() && table_[ind].key >= 0)
                 {
                     ind++;
@@ -139,23 +183,25 @@ class HashedMap
                 {
                     table_.resize(table_.size() + 1);
                 }
-                table_[ind_prev].next = ind;
+                table_[ind_prev].next           = ind;
 
-                startSpaceSearch_ = ind + 1;
+                startIndexForSpaceForListEntry_ = ind + 1;
             }
 
-            table_[ind].key   = key;
-            table_[ind].value = value;
+            table_[ind].key    = key;
+            table_[ind].value  = value;
+
+            numElements_      += 1;
         }
 
-        /*! \brief Delete the entry for key \p key
+        /*! \brief Delete the entry for key \p key, when present
          *
          * \param[in] key  The key
          */
         void erase(int key)
         {
             int ind_prev = -1;
-            int ind      = key % mod_;
+            int ind      = (key & bitMask_);
             do
             {
                 if (table_[ind].key == key)
@@ -167,13 +213,15 @@ class HashedMap
                         /* This index is a linked entry, so we free an entry.
                          * Check if we are creating the first empty space.
                          */
-                        if (ind < startSpaceSearch_)
+                        if (ind < startIndexForSpaceForListEntry_)
                         {
-                            startSpaceSearch_ = ind;
+                            startIndexForSpaceForListEntry_ = ind;
                         }
                     }
-                    table_[ind].key  = -1;
-                    table_[ind].next = -1;
+                    table_[ind].key   = -1;
+                    table_[ind].next  = -1;
+
+                    numElements_     -= 1;
 
                     return;
                 }
@@ -197,7 +245,7 @@ class HashedMap
          */
         const T *find(int key) const
         {
-            int ind = key % mod_;
+            int ind = (key & bitMask_);
             do
             {
                 if (table_[ind].key == key)
@@ -211,21 +259,43 @@ class HashedMap
             return nullptr;
         }
 
-        /*! \brief Clear all the entries in the list */
+        /*! \brief Clear all the entries in the list
+         *
+         * Also optimizes the size of the table based on the current
+         * number of elements stored.
+         */
         void clear()
         {
+            const int oldNumElements = numElements_;
+
             for (hashEntry &entry : table_)
             {
                 entry.key  = -1;
                 entry.next = -1;
             }
-            startSpaceSearch_ = mod_;
+            startIndexForSpaceForListEntry_ = bucket_count();
+            numElements_                    = 0;
+
+            /* Resize the hash table when the occupation is far from optimal.
+             * Do not resize with 0 elements to avoid minimal size when clear()
+             * is called twice in a row.
+             */
+            if (oldNumElements > 0 && (oldNumElements*c_relTableSizeThresholdMax < bucket_count() ||
+                                       oldNumElements*c_relTableSizeThresholdMin > bucket_count()))
+            {
+                resize(oldNumElements);
+            }
         }
 
     private:
-        std::vector<hashEntry> table_;            /**< The hash table list */
-        int                    mod_;              /**< The hash size */
-        int                    startSpaceSearch_; /**< Index in lal at which to start looking for empty space */
+        /*! \brief The hash table list */
+        std::vector<hashEntry> table_;
+        /*! \brief The bit mask for computing the hash of a key */
+        int                    bitMask_                        = 0;
+        /*! \brief Index in table_ at which to start looking for empty space for a new linked list entry */
+        int                    startIndexForSpaceForListEntry_ = 0;
+        /*! \brief The number of elements currently stored in the table */
+        int                    numElements_                    = 0;
 };
 
 } // namespace gmx
index bfc8fd3c991658cd900e01f70d637afdcf77a17d..d84725b8ab0a2c48d568e8f39664f969f4948d87 100644 (file)
@@ -85,6 +85,19 @@ TEST(HashedMap, InsertsFinds)
     checkDoesNotFind(map, 4);
 }
 
+TEST(HashedMap, NegativeKeysWork)
+{
+    gmx::HashedMap<char> map(5);
+
+    map.insert(-1, 'a');
+    map.insert(1,  'b');
+    map.insert(-3, 'c');
+
+    checkFinds(map, -1, 'a');
+    checkFinds(map, 1,  'b');
+    checkFinds(map, -3, 'c');
+}
+
 TEST(HashedMap, InsertsErases)
 {
     gmx::HashedMap<char> map(3);
@@ -112,16 +125,35 @@ TEST(HashedMap, Clears)
     checkDoesNotFind(map, 7);
 }
 
-// HashedMap only throws in debug mode, so only test in debug mode
-#ifndef NDEBUG
-
-TEST(HashedMap, CatchesInvalidKey)
+// Check that entries with the same hash are handled correctly
+TEST(HashedMap, LinkedEntries)
 {
-    gmx::HashedMap<char> map(101);
+    // HashedMap uses bit masking, so keys that differ by exactly
+    // a power of 2 larger than the table size will have the same hash
 
-    EXPECT_THROW_GMX(map.insert(-1, 'a'), gmx::InvalidInputError);
+    gmx::HashedMap<char> map(20);
+
+    const int            largePowerOf2 = 2048;
+
+    map.insert(3 + 0*largePowerOf2, 'a');
+    map.insert(3 + 1*largePowerOf2, 'b');
+    map.insert(3 + 2*largePowerOf2, 'c');
+
+    checkFinds(map, 3 + 0*largePowerOf2, 'a');
+    checkFinds(map, 3 + 1*largePowerOf2, 'b');
+    checkFinds(map, 3 + 2*largePowerOf2, 'c');
+
+    // Erase the middle entry in the linked list
+    map.erase(3 + 1*largePowerOf2);
+
+    checkFinds(map, 3 + 0*largePowerOf2, 'a');
+    checkDoesNotFind(map, 3 + 1*largePowerOf2);
+    checkFinds(map, 3 + 2*largePowerOf2, 'c');
 }
 
+// HashedMap only throws in debug mode, so only test in debug mode
+#ifndef NDEBUG
+
 TEST(HashedMap, CatchesDuplicateKey)
 {
     gmx::HashedMap<char> map(15);
@@ -133,4 +165,34 @@ TEST(HashedMap, CatchesDuplicateKey)
 
 #endif // NDEBUG
 
+// Check the table is resized after clear()
+TEST(HashedMap, ResizesTable)
+{
+    gmx::HashedMap<char> map(1);
+
+    // This test assumes the minimum bucket count is 64 or less
+    EXPECT_LT(map.bucket_count(), 128);
+
+    for (int i = 0; i < 60; i++)
+    {
+        map.insert(2*i + 3, 'a');
+    }
+    EXPECT_LT(map.bucket_count(), 128);
+
+    // Check that the table size is double #elements after clear()
+    map.clear();
+    EXPECT_EQ(map.bucket_count(), 128);
+
+    // Check that calling clear() a second time does not resize
+    map.clear();
+    EXPECT_EQ(map.bucket_count(), 128);
+
+    map.insert(2, 'b');
+    EXPECT_EQ(map.bucket_count(), 128);
+
+    // Check that calling clear with 1 elements sizes down
+    map.clear();
+    EXPECT_LT(map.bucket_count(), 128);
+}
+
 }      // namespace