Replace nbnxn_buffer_flags_t with vector
[alexxy/gromacs.git] / src / gromacs / nbnxm / pairlist.cpp
index 66ec8689d9b98c0ed8bba0d16f9aabbb664753b9..4365b02a0842075a86ef511d63aac1226f22c7d8 100644 (file)
@@ -239,20 +239,19 @@ void nbnxn_init_pairlist_fep(t_nblist* nl)
     nl->excl_fep = nullptr;
 }
 
-static void init_buffer_flags(nbnxn_buffer_flags_t* flags, int natoms)
+static constexpr int sizeNeededForBufferFlags(const int numAtoms)
 {
-    flags->nflag = (natoms + NBNXN_BUFFERFLAG_SIZE - 1) / NBNXN_BUFFERFLAG_SIZE;
-    if (flags->nflag > flags->flag_nalloc)
-    {
-        flags->flag_nalloc = over_alloc_large(flags->nflag);
-        srenew(flags->flag, flags->flag_nalloc);
-    }
-    for (int b = 0; b < flags->nflag; b++)
-    {
-        bitmask_clear(&(flags->flag[b]));
-    }
+    return (numAtoms + NBNXN_BUFFERFLAG_SIZE - 1) / NBNXN_BUFFERFLAG_SIZE;
 }
 
+// Resets current flags to 0 and adds more flags if needed.
+static void resizeAndZeroBufferFlags(std::vector<gmx_bitmask_t>* flags, const int numAtoms)
+{
+    flags->clear();
+    flags->resize(sizeNeededForBufferFlags(numAtoms), 0);
+}
+
+
 /* Returns the pair-list cutoff between a bounding box and a grid cell given an atom-to-atom pair-list cutoff
  *
  * Given a cutoff distance between atoms, this functions returns the cutoff
@@ -3126,7 +3125,7 @@ static void nbnxn_make_pairlist_part(const Nbnxm::GridSet&   gridSet,
         gridi_flag_shift = getBufferFlagShift(nbl->na_ci);
         gridj_flag_shift = getBufferFlagShift(nbl->na_cj);
 
-        gridj_flag = work->buffer_flags.flag;
+        gridj_flag = work->buffer_flags.data();
     }
 
     gridSet.getBox(box);
@@ -3560,7 +3559,7 @@ static void nbnxn_make_pairlist_part(const Nbnxm::GridSet&   gridSet,
 
         if (bFBufferFlag && getNumSimpleJClustersInList(*nbl) > ncj_old_i)
         {
-            bitmask_init_bit(&(work->buffer_flags.flag[(iGrid.cellOffset() + ci) >> gridi_flag_shift]), th);
+            bitmask_init_bit(&(work->buffer_flags[(iGrid.cellOffset() + ci) >> gridi_flag_shift]), th);
         }
     }
 
@@ -3583,20 +3582,21 @@ static void nbnxn_make_pairlist_part(const Nbnxm::GridSet&   gridSet,
 
 static void reduce_buffer_flags(gmx::ArrayRef<PairsearchWork> searchWork,
                                 int                           nsrc,
-                                const nbnxn_buffer_flags_t*   dest)
+                                gmx::ArrayRef<gmx_bitmask_t>  dest)
 {
     for (int s = 0; s < nsrc; s++)
     {
-        gmx_bitmask_t* flag = searchWork[s].buffer_flags.flag;
+        gmx::ArrayRef<gmx_bitmask_t> flags(searchWork[s].buffer_flags);
 
-        for (int b = 0; b < dest->nflag; b++)
+        for (size_t b = 0; b < dest.size(); b++)
         {
-            bitmask_union(&(dest->flag[b]), flag[b]);
+            gmx_bitmask_t& flag = dest[b];
+            bitmask_union(&flag, flags[b]);
         }
     }
 }
 
-static void print_reduction_cost(const nbnxn_buffer_flags_t* flags, int nout)
+static void print_reduction_cost(gmx::ArrayRef<const gmx_bitmask_t> flags, int nout)
 {
     int           nelem, nkeep, ncopy, nred, out;
     gmx_bitmask_t mask_0;
@@ -3606,20 +3606,20 @@ static void print_reduction_cost(const nbnxn_buffer_flags_t* flags, int nout)
     ncopy = 0;
     nred  = 0;
     bitmask_init_bit(&mask_0, 0);
-    for (int b = 0; b < flags->nflag; b++)
+    for (const gmx_bitmask_t& flag_mask : flags)
     {
-        if (bitmask_is_equal(flags->flag[b], mask_0))
+        if (bitmask_is_equal(flag_mask, mask_0))
         {
             /* Only flag 0 is set, no copy of reduction required */
             nelem++;
             nkeep++;
         }
-        else if (!bitmask_is_zero(flags->flag[b]))
+        else if (!bitmask_is_zero(flag_mask))
         {
             int c = 0;
             for (out = 0; out < nout; out++)
             {
-                if (bitmask_is_set(flags->flag[b], out))
+                if (bitmask_is_set(flag_mask, out))
                 {
                     c++;
                 }
@@ -3635,12 +3635,10 @@ static void print_reduction_cost(const nbnxn_buffer_flags_t* flags, int nout)
             }
         }
     }
-
+    const auto numFlags = static_cast<double>(flags.size());
     fprintf(debug,
-            "nbnxn reduction: #flag %d #list %d elem %4.2f, keep %4.2f copy %4.2f red %4.2f\n",
-            flags->nflag, nout, nelem / static_cast<double>(flags->nflag),
-            nkeep / static_cast<double>(flags->nflag), ncopy / static_cast<double>(flags->nflag),
-            nred / static_cast<double>(flags->nflag));
+            "nbnxn reduction: #flag %lu #list %d elem %4.2f, keep %4.2f copy %4.2f red %4.2f\n",
+            flags.size(), nout, nelem / numFlags, nkeep / numFlags, ncopy / numFlags, nred / numFlags);
 }
 
 /* Copies the list entries from src to dest when cjStart <= *cjGlobal < cjEnd.
@@ -3740,7 +3738,7 @@ static void rebalanceSimpleLists(gmx::ArrayRef<const NbnxnPairlistCpu> srcSet,
         /* Note that the flags in the work struct (still) contain flags
          * for all entries that are present in srcSet->nbl[t].
          */
-        gmx_bitmask_t* flag = searchWork[t].buffer_flags.flag;
+        gmx_bitmask_t* flag = &searchWork[t].buffer_flags[0];
 
         int iFlagShift = getBufferFlagShift(dest.na_ci);
         int jFlagShift = getBufferFlagShift(dest.na_cj);
@@ -3946,7 +3944,7 @@ void PairlistSet::constructPairlists(const Nbnxm::GridSet&         gridSet,
     /* We should re-init the flags before making the first list */
     if (nbat->bUseBufferFlags && locality_ == InteractionLocality::Local)
     {
-        init_buffer_flags(&nbat->buffer_flags, nbat->numAtoms());
+        resizeAndZeroBufferFlags(&nbat->buffer_flags, nbat->numAtoms());
     }
 
     if (!isCpuType_ && minimumIlistCountForGpuBalancing > 0)
@@ -4016,7 +4014,7 @@ void PairlistSet::constructPairlists(const Nbnxm::GridSet&         gridSet,
                      */
                     if (nbat->bUseBufferFlags && (iZone == 0 && jZone == 0))
                     {
-                        init_buffer_flags(&searchWork[th].buffer_flags, nbat->numAtoms());
+                        resizeAndZeroBufferFlags(&searchWork[th].buffer_flags, nbat->numAtoms());
                     }
 
                     if (combineLists_ && th > 0)
@@ -4133,7 +4131,7 @@ void PairlistSet::constructPairlists(const Nbnxm::GridSet&         gridSet,
 
     if (nbat->bUseBufferFlags)
     {
-        reduce_buffer_flags(searchWork, numLists, &nbat->buffer_flags);
+        reduce_buffer_flags(searchWork, numLists, nbat->buffer_flags);
     }
 
     if (gridSet.haveFep())
@@ -4187,7 +4185,7 @@ void PairlistSet::constructPairlists(const Nbnxm::GridSet&         gridSet,
 
         if (nbat->bUseBufferFlags)
         {
-            print_reduction_cost(&nbat->buffer_flags, numLists);
+            print_reduction_cost(nbat->buffer_flags, numLists);
         }
     }