Change nbnxn_search to class PairSearch
[alexxy/gromacs.git] / src / gromacs / nbnxm / pairlist.cpp
index 1a4ec994c0c15d42f624b4f9bd8513f56d80be58..d7b4e483a17799fd0ae4e6959de8c6cbadf1ac0a 100644 (file)
@@ -91,41 +91,27 @@ using InteractionLocality = Nbnxm::InteractionLocality;
 constexpr bool c_pbcShiftBackward = true;
 
 
-static void nbs_cycle_clear(nbnxn_cycle_t *cc)
-{
-    for (int i = 0; i < enbsCCnr; i++)
-    {
-        cc[i].count = 0;
-        cc[i].c     = 0;
-    }
-}
-
-static double Mcyc_av(const nbnxn_cycle_t *cc)
-{
-    return static_cast<double>(cc->c)*1e-6/cc->count;
-}
-
-static void nbs_cycle_print(FILE *fp, const nbnxn_search *nbs)
+void PairSearch::SearchCycleCounting::printCycles(FILE                               *fp,
+                                                  gmx::ArrayRef<const PairsearchWork> work) const
 {
     fprintf(fp, "\n");
-    fprintf(fp, "ns %4d grid %4.1f search %4.1f red.f %5.3f",
-            nbs->cc[enbsCCgrid].count,
-            Mcyc_av(&nbs->cc[enbsCCgrid]),
-            Mcyc_av(&nbs->cc[enbsCCsearch]),
-            Mcyc_av(&nbs->cc[enbsCCreducef]));
+    fprintf(fp, "ns %4d grid %4.1f search %4.1f",
+            cc_[enbsCCgrid].count(),
+            cc_[enbsCCgrid].averageMCycles(),
+            cc_[enbsCCsearch].averageMCycles());
 
-    if (nbs->work.size() > 1)
+    if (work.size() > 1)
     {
-        if (nbs->cc[enbsCCcombine].count > 0)
+        if (cc_[enbsCCcombine].count() > 0)
         {
             fprintf(fp, " comb %5.2f",
-                    Mcyc_av(&nbs->cc[enbsCCcombine]));
+                    cc_[enbsCCcombine].averageMCycles());
         }
         fprintf(fp, " s. th");
-        for (const nbnxn_search_work_t &work : nbs->work)
+        for (const PairsearchWork &workEntry : work)
         {
             fprintf(fp, " %4.1f",
-                    Mcyc_av(&work.cc[enbsCCsearch]));
+                    workEntry.cycleCounter.averageMCycles());
         }
     }
     fprintf(fp, "\n");
@@ -279,7 +265,7 @@ static void free_nblist(t_nblist *nl)
     sfree(nl->excl_fep);
 }
 
-nbnxn_search_work_t::nbnxn_search_work_t() :
+PairsearchWork::PairsearchWork() :
     cp0({{0}}
         ),
     buffer_flags({0, nullptr, 0}),
@@ -288,20 +274,19 @@ nbnxn_search_work_t::nbnxn_search_work_t() :
     cp1({{0}})
 {
     nbnxn_init_pairlist_fep(nbl_fep.get());
-
-    nbs_cycle_clear(cc);
 }
 
-nbnxn_search_work_t::~nbnxn_search_work_t()
+PairsearchWork::~PairsearchWork()
 {
     sfree(buffer_flags.flag);
 
     free_nblist(nbl_fep.get());
 }
 
-nbnxn_search::DomainSetup::DomainSetup(const int                 ePBC,
-                                       const ivec               *numDDCells,
-                                       const gmx_domdec_zones_t *ddZones) :
+// TODO: Move to pairsearch.cpp
+PairSearch::DomainSetup::DomainSetup(const int                 ePBC,
+                                     const ivec               *numDDCells,
+                                     const gmx_domdec_zones_t *ddZones) :
     ePBC(ePBC),
     haveDomDec(numDDCells != nullptr),
     zones(ddZones)
@@ -312,19 +297,18 @@ nbnxn_search::DomainSetup::DomainSetup(const int                 ePBC,
     }
 }
 
-nbnxn_search::nbnxn_search(const int                 ePBC,
-                           const ivec               *numDDCells,
-                           const gmx_domdec_zones_t *ddZones,
-                           const PairlistType        pairlistType,
-                           const bool                haveFep,
-                           const int                 maxNumThreads) :
+// TODO: Move to pairsearch.cpp
+PairSearch::PairSearch(const int                 ePBC,
+                       const ivec               *numDDCells,
+                       const gmx_domdec_zones_t *ddZones,
+                       const PairlistType        pairlistType,
+                       const bool                haveFep,
+                       const int                 maxNumThreads) :
     domainSetup_(ePBC, numDDCells, ddZones),
     gridSet_(domainSetup_.haveDomDecPerDim, pairlistType, haveFep, maxNumThreads),
-    search_count(0),
-    work(maxNumThreads)
+    work_(maxNumThreads)
 {
-    print_cycles = (getenv("GMX_NBNXN_CYCLE") != nullptr);
-    nbs_cycle_clear(cc);
+    cycleCounting_.recordCycles_ = (getenv("GMX_NBNXN_CYCLE") != nullptr);
 }
 
 static void init_buffer_flags(nbnxn_buffer_flags_t *flags,
@@ -853,10 +837,12 @@ void nbnxn_init_pairlist_set(nbnxn_pairlist_set_t *nbl_list)
 }
 
 /* Print statistics of a pair list, used for debug output */
-static void print_nblist_statistics(FILE *fp, const NbnxnPairlistCpu *nbl,
-                                    const nbnxn_search *nbs, real rl)
+static void print_nblist_statistics(FILE                   *fp,
+                                    const NbnxnPairlistCpu *nbl,
+                                    const PairSearch       &pairSearch,
+                                    const real              rl)
 {
-    const Grid             &grid = nbs->gridSet().grids()[0];
+    const Grid             &grid = pairSearch.gridSet().grids()[0];
     const Grid::Dimensions &dims = grid.dimensions();
 
     fprintf(fp, "nbl nci %zu ncj %d\n",
@@ -898,10 +884,12 @@ static void print_nblist_statistics(FILE *fp, const NbnxnPairlistCpu *nbl,
 }
 
 /* Print statistics of a pair lists, used for debug output */
-static void print_nblist_statistics(FILE *fp, const NbnxnPairlistGpu *nbl,
-                                    const nbnxn_search *nbs, real rl)
+static void print_nblist_statistics(FILE                   *fp,
+                                    const NbnxnPairlistGpu *nbl,
+                                    const PairSearch       &pairSearch,
+                                    const real              rl)
 {
-    const Grid             &grid = nbs->gridSet().grids()[0];
+    const Grid             &grid = pairSearch.gridSet().grids()[0];
     const Grid::Dimensions &dims = grid.dimensions();
 
     fprintf(fp, "nbl nsci %zu ncj4 %zu nsi %d excl4 %zu\n",
@@ -2590,7 +2578,7 @@ static real nonlocal_vol2(const struct gmx_domdec_zones_t *zones, const rvec ls,
 }
 
 /* Estimates the average size of a full j-list for super/sub setup */
-static void get_nsubpair_target(const nbnxn_search        *nbs,
+static void get_nsubpair_target(const PairSearch          &pairSearch,
                                 const InteractionLocality  iloc,
                                 const real                 rlist,
                                 const int                  min_ci_balanced,
@@ -2603,7 +2591,7 @@ static void get_nsubpair_target(const nbnxn_search        *nbs,
     const int           nsubpair_target_min = 36;
     real                r_eff_sup, vol_est, nsp_est, nsp_est_nl;
 
-    const Grid         &grid = nbs->gridSet().grids()[0];
+    const Grid         &grid = pairSearch.gridSet().grids()[0];
 
     /* We don't need to balance list sizes if:
      * - We didn't request balancing.
@@ -2631,7 +2619,8 @@ static void get_nsubpair_target(const nbnxn_search        *nbs,
     /* The formulas below are a heuristic estimate of the average nsj per si*/
     r_eff_sup = rlist + nbnxn_get_rlist_effective_inc(numAtomsCluster, ls);
 
-    if (!nbs->domainSetup().haveDomDec || nbs->domainSetup().zones->n == 1)
+    if (!pairSearch.domainSetup().haveDomDec ||
+        pairSearch.domainSetup().zones->n == 1)
     {
         nsp_est_nl = 0;
     }
@@ -2639,7 +2628,7 @@ static void get_nsubpair_target(const nbnxn_search        *nbs,
     {
         nsp_est_nl =
             gmx::square(dims.atomDensity/numAtomsCluster)*
-            nonlocal_vol2(nbs->domainSetup().zones, ls, r_eff_sup);
+            nonlocal_vol2(pairSearch.domainSetup().zones, ls, r_eff_sup);
     }
 
     if (iloc == InteractionLocality::Local)
@@ -2825,8 +2814,8 @@ static void combine_nblists(int nnbl, NbnxnPairlistGpu **nbl,
     }
 }
 
-static void balance_fep_lists(const nbnxn_search   *nbs,
-                              nbnxn_pairlist_set_t *nbl_lists)
+static void balance_fep_lists(gmx::ArrayRef<PairsearchWork>       work,
+                              nbnxn_pairlist_set_t               *nbl_lists)
 {
     int       nnbl;
     int       nri_tot, nrj_tot, nrj_target;
@@ -2859,7 +2848,7 @@ static void balance_fep_lists(const nbnxn_search   *nbs,
     {
         try
         {
-            t_nblist *nbl = nbs->work[th].nbl_fep.get();
+            t_nblist *nbl = work[th].nbl_fep.get();
 
             /* Note that here we allocate for the total size, instead of
              * a per-thread esimate (which is hard to obtain).
@@ -2883,7 +2872,7 @@ static void balance_fep_lists(const nbnxn_search   *nbs,
 
     /* Loop over the source lists and assign and copy i-entries */
     th_dest = 0;
-    nbld    = nbs->work[th_dest].nbl_fep.get();
+    nbld    = work[th_dest].nbl_fep.get();
     for (int th = 0; th < nnbl; th++)
     {
         t_nblist *nbls;
@@ -2904,7 +2893,7 @@ static void balance_fep_lists(const nbnxn_search   *nbs,
                 nbld->nrj + nrj - nrj_target > nrj_target - nbld->nrj)
             {
                 th_dest++;
-                nbld = nbs->work[th_dest].nbl_fep.get();
+                nbld = work[th_dest].nbl_fep.get();
             }
 
             nbld->iinr[nbld->nri]  = nbls->iinr[i];
@@ -2925,8 +2914,8 @@ static void balance_fep_lists(const nbnxn_search   *nbs,
     /* Swap the list pointers */
     for (int th = 0; th < nnbl; th++)
     {
-        t_nblist *nbl_tmp      = nbs->work[th].nbl_fep.release();
-        nbs->work[th].nbl_fep.reset(nbl_lists->nbl_fep[th]);
+        t_nblist *nbl_tmp      = work[th].nbl_fep.release();
+        work[th].nbl_fep.reset(nbl_lists->nbl_fep[th]);
         nbl_lists->nbl_fep[th] = nbl_tmp;
 
         if (debug)
@@ -3229,10 +3218,10 @@ static void setBufferFlags(const NbnxnPairlistGpu gmx_unused &nbl,
 
 /* Generates the part of pair-list nbl assigned to our thread */
 template <typename T>
-static void nbnxn_make_pairlist_part(const nbnxn_search *nbs,
+static void nbnxn_make_pairlist_part(const PairSearch &pairSearch,
                                      const Grid &iGrid,
                                      const Grid &jGrid,
-                                     nbnxn_search_work_t *work,
+                                     PairsearchWork *work,
                                      const nbnxn_atomdata_t *nbat,
                                      const t_blocka &exclusions,
                                      real rlist,
@@ -3261,8 +3250,6 @@ static void nbnxn_make_pairlist_part(const nbnxn_search *nbs,
     gmx_bitmask_t    *gridj_flag       = nullptr;
     int               ncj_old_i, ncj_old_j;
 
-    nbs_cycle_start(&work->cc[enbsCCsearch]);
-
     if (jGrid.geometry().isSimple != pairlistIsSimple(*nbl) ||
         iGrid.geometry().isSimple != pairlistIsSimple(*nbl))
     {
@@ -3286,7 +3273,7 @@ static void nbnxn_make_pairlist_part(const nbnxn_search *nbs,
         gridj_flag       = work->buffer_flags.flag;
     }
 
-    const Nbnxm::GridSet &gridSet = nbs->gridSet();
+    const Nbnxm::GridSet &gridSet = pairSearch.gridSet();
 
     gridSet.getBox(box);
 
@@ -3329,8 +3316,8 @@ static void nbnxn_make_pairlist_part(const nbnxn_search *nbs,
         /* Check if we need periodicity shifts.
          * Without PBC or with domain decomposition we don't need them.
          */
-        if (d >= ePBC2npbcdim(nbs->domainSetup().ePBC) ||
-            nbs->domainSetup().haveDomDecPerDim[d])
+        if (d >= ePBC2npbcdim(pairSearch.domainSetup().ePBC) ||
+            pairSearch.domainSetup().haveDomDecPerDim[d])
         {
             shp[d] = 0;
         }
@@ -3735,15 +3722,13 @@ static void nbnxn_make_pairlist_part(const nbnxn_search *nbs,
 
     work->ndistc = numDistanceChecks;
 
-    nbs_cycle_stop(&work->cc[enbsCCsearch]);
-
     checkListSizeConsistency(*nbl, haveFep);
 
     if (debug)
     {
         fprintf(debug, "number of distance checks %d\n", numDistanceChecks);
 
-        print_nblist_statistics(debug, nbl, nbs, rlist);
+        print_nblist_statistics(debug, nbl, pairSearch, rlist);
 
         if (haveFep)
         {
@@ -3752,13 +3737,13 @@ static void nbnxn_make_pairlist_part(const nbnxn_search *nbs,
     }
 }
 
-static void reduce_buffer_flags(const nbnxn_search         *nbs,
+static void reduce_buffer_flags(const PairSearch           &pairSearch,
                                 int                         nsrc,
                                 const nbnxn_buffer_flags_t *dest)
 {
     for (int s = 0; s < nsrc; s++)
     {
-        gmx_bitmask_t * flag = nbs->work[s].buffer_flags.flag;
+        gmx_bitmask_t * flag = pairSearch.work()[s].buffer_flags.flag;
 
         for (int b = 0; b < dest->nflag; b++)
         {
@@ -3864,7 +3849,7 @@ static void copySelectedListRange(const nbnxn_ci_t * gmx_restrict srcCi,
 static void rebalanceSimpleLists(int                                  numLists,
                                  NbnxnPairlistCpu * const * const     srcSet,
                                  NbnxnPairlistCpu                   **destSet,
-                                 gmx::ArrayRef<nbnxn_search_work_t>   searchWork)
+                                 gmx::ArrayRef<PairsearchWork>        searchWork)
 {
     int ncjTotal = 0;
     for (int s = 0; s < numLists; s++)
@@ -4032,7 +4017,7 @@ static void sort_sci(NbnxnPairlistGpu *nbl)
 
 void
 nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality,
-                                            nbnxn_search              *nbs,
+                                            PairSearch                *pairSearch,
                                             nbnxn_atomdata_t          *nbat,
                                             const t_blocka            *excl,
                                             const Nbnxm::KernelType    kernelType,
@@ -4074,12 +4059,12 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
     }
     else
     {
-        nzi = nbs->domainSetup().zones->nizone;
+        nzi = pairSearch->domainSetup().zones->nizone;
     }
 
     if (!nbl_list->bSimple && minimumIlistCountForGpuBalancing_ > 0)
     {
-        get_nsubpair_target(nbs, iLocality, rlist, minimumIlistCountForGpuBalancing_,
+        get_nsubpair_target(*pairSearch, iLocality, rlist, minimumIlistCountForGpuBalancing_,
                             &nsubpair_target, &nsubpair_tot_est);
     }
     else
@@ -4100,17 +4085,17 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
             clear_pairlist(nbl_list->nblGpu[th]);
         }
 
-        if (nbs->gridSet().haveFep())
+        if (pairSearch->gridSet().haveFep())
         {
             clear_pairlist_fep(nbl_list->nbl_fep[th]);
         }
     }
 
-    const gmx_domdec_zones_t *ddZones = nbs->domainSetup().zones;
+    const gmx_domdec_zones_t *ddZones = pairSearch->domainSetup().zones;
 
     for (int zi = 0; zi < nzi; zi++)
     {
-        const Grid &iGrid = nbs->gridSet().grids()[zi];
+        const Grid &iGrid = pairSearch->gridSet().grids()[zi];
 
         int                 zj0;
         int                 zj1;
@@ -4130,16 +4115,16 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
         }
         for (int zj = zj0; zj < zj1; zj++)
         {
-            const Grid &jGrid = nbs->gridSet().grids()[zj];
+            const Grid &jGrid = pairSearch->gridSet().grids()[zj];
 
             if (debug)
             {
                 fprintf(debug, "ns search grid %d vs %d\n", zi, zj);
             }
 
-            nbs_cycle_start(&nbs->cc[enbsCCsearch]);
+            pairSearch->cycleCounting_.start(PairSearch::enbsCCsearch);
 
-            ci_block = get_ci_block_size(iGrid, nbs->domainSetup().haveDomDec, nnbl);
+            ci_block = get_ci_block_size(iGrid, pairSearch->domainSetup().haveDomDec, nnbl);
 
             /* With GPU: generate progressively smaller lists for
              * load balancing for local only or non-local with 2 zones.
@@ -4156,7 +4141,7 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
                      */
                     if (nbat->bUseBufferFlags && ((zi == 0 && zj == 0)))
                     {
-                        init_buffer_flags(&nbs->work[th].buffer_flags, nbat->numAtoms());
+                        init_buffer_flags(&pairSearch->work()[th].buffer_flags, nbat->numAtoms());
                     }
 
                     if (CombineNBLists && th > 0)
@@ -4166,11 +4151,15 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
                         clear_pairlist(nbl_list->nblGpu[th]);
                     }
 
+                    auto &searchWork = pairSearch->work()[th];
+
+                    searchWork.cycleCounter.start();
+
                     /* Divide the i super cell equally over the nblists */
                     if (nbl_list->bSimple)
                     {
-                        nbnxn_make_pairlist_part(nbs, iGrid, jGrid,
-                                                 &nbs->work[th], nbat, *excl,
+                        nbnxn_make_pairlist_part(*pairSearch, iGrid, jGrid,
+                                                 &searchWork, nbat, *excl,
                                                  rlist,
                                                  kernelType,
                                                  ci_block,
@@ -4183,8 +4172,8 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
                     }
                     else
                     {
-                        nbnxn_make_pairlist_part(nbs, iGrid, jGrid,
-                                                 &nbs->work[th], nbat, *excl,
+                        nbnxn_make_pairlist_part(*pairSearch, iGrid, jGrid,
+                                                 &searchWork, nbat, *excl,
                                                  rlist,
                                                  kernelType,
                                                  ci_block,
@@ -4195,17 +4184,19 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
                                                  nbl_list->nblGpu[th],
                                                  nbl_list->nbl_fep[th]);
                     }
+
+                    searchWork.cycleCounter.stop();
                 }
                 GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
             }
-            nbs_cycle_stop(&nbs->cc[enbsCCsearch]);
+            pairSearch->cycleCounting_.stop(PairSearch::enbsCCsearch);
 
             np_tot = 0;
             np_noq = 0;
             np_hlj = 0;
             for (int th = 0; th < nnbl; th++)
             {
-                inc_nrnb(nrnb, eNR_NBNXN_DIST2, nbs->work[th].ndistc);
+                inc_nrnb(nrnb, eNR_NBNXN_DIST2, pairSearch->work()[th].ndistc);
 
                 if (nbl_list->bSimple)
                 {
@@ -4238,11 +4229,11 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
                 GMX_ASSERT(!nbl_list->bSimple, "Can only combine GPU lists");
                 NbnxnPairlistGpu **nbl = nbl_list->nblGpu;
 
-                nbs_cycle_start(&nbs->cc[enbsCCcombine]);
+                pairSearch->cycleCounting_.start(PairSearch::enbsCCcombine);
 
                 combine_nblists(nnbl-1, nbl+1, nbl[0]);
 
-                nbs_cycle_stop(&nbs->cc[enbsCCcombine]);
+                pairSearch->cycleCounting_.stop(PairSearch::enbsCCcombine);
             }
         }
     }
@@ -4251,7 +4242,7 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
     {
         if (nnbl > 1 && checkRebalanceSimpleLists(nbl_list))
         {
-            rebalanceSimpleLists(nbl_list->nnbl, nbl_list->nbl, nbl_list->nbl_work, nbs->work);
+            rebalanceSimpleLists(nbl_list->nnbl, nbl_list->nbl, nbl_list->nbl_work, pairSearch->work());
 
             /* Swap the pointer of the sets of pair lists */
             NbnxnPairlistCpu **tmp = nbl_list->nbl;
@@ -4282,13 +4273,13 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
 
     if (nbat->bUseBufferFlags)
     {
-        reduce_buffer_flags(nbs, nbl_list->nnbl, &nbat->buffer_flags);
+        reduce_buffer_flags(*pairSearch, nbl_list->nnbl, &nbat->buffer_flags);
     }
 
-    if (nbs->gridSet().haveFep())
+    if (pairSearch->gridSet().haveFep())
     {
         /* Balance the free-energy lists over all the threads */
-        balance_fep_lists(nbs, nbl_list);
+        balance_fep_lists(pairSearch->work(), nbl_list);
     }
 
     if (nbl_list->bSimple)
@@ -4312,13 +4303,13 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
     /* Special performance logging stuff (env.var. GMX_NBNXN_CYCLE) */
     if (iLocality == InteractionLocality::Local)
     {
-        nbs->search_count++;
+        pairSearch->cycleCounting_.searchCount_++;
     }
-    if (nbs->print_cycles &&
-        (!nbs->domainSetup().haveDomDec || iLocality == InteractionLocality::NonLocal) &&
-        nbs->search_count % 100 == 0)
+    if (pairSearch->cycleCounting_.recordCycles_ &&
+        (!pairSearch->domainSetup().haveDomDec || iLocality == InteractionLocality::NonLocal) &&
+        pairSearch->cycleCounting_.searchCount_ % 100 == 0)
     {
-        nbs_cycle_print(stderr, nbs);
+        pairSearch->cycleCounting_.printCycles(stderr, pairSearch->work());
     }
 
     /* If we have more than one list, they either got rebalancing (CPU)
@@ -4330,12 +4321,12 @@ nonbonded_verlet_t::PairlistSets::construct(const InteractionLocality  iLocality
         {
             for (int t = 0; t < nbl_list->nnbl; t++)
             {
-                print_nblist_statistics(debug, nbl_list->nbl[t], nbs, rlist);
+                print_nblist_statistics(debug, nbl_list->nbl[t], *pairSearch, rlist);
             }
         }
         else
         {
-            print_nblist_statistics(debug, nbl_list->nblGpu[0], nbs, rlist);
+            print_nblist_statistics(debug, nbl_list->nblGpu[0], *pairSearch, rlist);
         }
     }
 
@@ -4374,7 +4365,7 @@ nonbonded_verlet_t::constructPairlist(const Nbnxm::InteractionLocality  iLocalit
                                       int64_t                           step,
                                       t_nrnb                           *nrnb)
 {
-    pairlistSets_->construct(iLocality, nbs.get(), nbat.get(), excl,
+    pairlistSets_->construct(iLocality, pairSearch_.get(), nbat.get(), excl,
                              kernelSetup_.kernelType,
                              step, nrnb);