Move nbnxm domainSetup to GridSet
[alexxy/gromacs.git] / src / gromacs / nbnxm / pairlist.cpp
index aba3775eb93a5a58d2ded242bb06c1416eacfc39..847cf591d8832c9a3a20dc18fad56ba4ba74b162 100644 (file)
@@ -750,10 +750,10 @@ PairlistSet::PairlistSet(const Nbnxm::InteractionLocality  locality,
 /* Print statistics of a pair list, used for debug output */
 static void print_nblist_statistics(FILE                   *fp,
                                     const NbnxnPairlistCpu &nbl,
-                                    const PairSearch       &pairSearch,
+                                    const Nbnxm::GridSet   &gridSet,
                                     const real              rl)
 {
-    const Grid             &grid = pairSearch.gridSet().grids()[0];
+    const Grid             &grid = gridSet.grids()[0];
     const Grid::Dimensions &dims = grid.dimensions();
 
     fprintf(fp, "nbl nci %zu ncj %d\n",
@@ -797,10 +797,10 @@ static void print_nblist_statistics(FILE                   *fp,
 /* Print statistics of a pair lists, used for debug output */
 static void print_nblist_statistics(FILE                   *fp,
                                     const NbnxnPairlistGpu &nbl,
-                                    const PairSearch       &pairSearch,
+                                    const Nbnxm::GridSet   &gridSet,
                                     const real              rl)
 {
-    const Grid             &grid = pairSearch.gridSet().grids()[0];
+    const Grid             &grid = gridSet.grids()[0];
     const Grid::Dimensions &dims = grid.dimensions();
 
     fprintf(fp, "nbl nsci %zu ncj4 %zu nsi %d excl4 %zu\n",
@@ -2489,7 +2489,7 @@ static real nonlocal_vol2(const struct gmx_domdec_zones_t *zones, const rvec ls,
 }
 
 /* Estimates the average size of a full j-list for super/sub setup */
-static void get_nsubpair_target(const PairSearch          &pairSearch,
+static void get_nsubpair_target(const Nbnxm::GridSet      &gridSet,
                                 const InteractionLocality  iloc,
                                 const real                 rlist,
                                 const int                  min_ci_balanced,
@@ -2502,7 +2502,7 @@ static void get_nsubpair_target(const PairSearch          &pairSearch,
     const int           nsubpair_target_min = 36;
     real                r_eff_sup, vol_est, nsp_est, nsp_est_nl;
 
-    const Grid         &grid = pairSearch.gridSet().grids()[0];
+    const Grid         &grid = gridSet.grids()[0];
 
     /* We don't need to balance list sizes if:
      * - We didn't request balancing.
@@ -2530,8 +2530,8 @@ static void get_nsubpair_target(const PairSearch          &pairSearch,
     /* The formulas below are a heuristic estimate of the average nsj per si*/
     r_eff_sup = rlist + nbnxn_get_rlist_effective_inc(numAtomsCluster, ls);
 
-    if (!pairSearch.domainSetup().haveDomDec ||
-        pairSearch.domainSetup().zones->n == 1)
+    if (!gridSet.domainSetup().haveMultipleDomains ||
+        gridSet.domainSetup().zones->n == 1)
     {
         nsp_est_nl = 0;
     }
@@ -2539,7 +2539,7 @@ static void get_nsubpair_target(const PairSearch          &pairSearch,
     {
         nsp_est_nl =
             gmx::square(dims.atomDensity/numAtomsCluster)*
-            nonlocal_vol2(pairSearch.domainSetup().zones, ls, r_eff_sup);
+            nonlocal_vol2(gridSet.domainSetup().zones, ls, r_eff_sup);
     }
 
     if (iloc == InteractionLocality::Local)
@@ -2912,7 +2912,7 @@ static float boundingbox_only_distance2(const Grid::Dimensions &iGridDims,
 }
 
 static int get_ci_block_size(const Grid &iGrid,
-                             const bool  haveDomDec,
+                             const bool  haveMultipleDomains,
                              const int   numLists)
 {
     const int ci_block_enum      = 5;
@@ -2946,7 +2946,7 @@ static int get_ci_block_size(const Grid &iGrid,
     /* Without domain decomposition
      * or with less than 3 blocks per task, divide in nth blocks.
      */
-    if (!haveDomDec || numLists*3*ci_block > iGrid.numCells())
+    if (!haveMultipleDomains || numLists*3*ci_block > iGrid.numCells())
     {
         ci_block = (iGrid.numCells() + numLists - 1)/numLists;
     }
@@ -3128,7 +3128,7 @@ static void setBufferFlags(const NbnxnPairlistGpu gmx_unused &nbl,
 
 /* Generates the part of pair-list nbl assigned to our thread */
 template <typename T>
-static void nbnxn_make_pairlist_part(const PairSearch &pairSearch,
+static void nbnxn_make_pairlist_part(const Nbnxm::GridSet &gridSet,
                                      const Grid &iGrid,
                                      const Grid &jGrid,
                                      PairsearchWork *work,
@@ -3183,8 +3183,6 @@ static void nbnxn_make_pairlist_part(const PairSearch &pairSearch,
         gridj_flag       = work->buffer_flags.flag;
     }
 
-    const Nbnxm::GridSet &gridSet = pairSearch.gridSet();
-
     gridSet.getBox(box);
 
     const bool            haveFep = gridSet.haveFep();
@@ -3226,8 +3224,8 @@ static void nbnxn_make_pairlist_part(const PairSearch &pairSearch,
         /* Check if we need periodicity shifts.
          * Without PBC or with domain decomposition we don't need them.
          */
-        if (d >= ePBC2npbcdim(pairSearch.domainSetup().ePBC) ||
-            pairSearch.domainSetup().haveDomDecPerDim[d])
+        if (d >= ePBC2npbcdim(gridSet.domainSetup().ePBC) ||
+            gridSet.domainSetup().haveMultipleDomainsPerDim[d])
         {
             shp[d] = 0;
         }
@@ -3638,7 +3636,7 @@ static void nbnxn_make_pairlist_part(const PairSearch &pairSearch,
     {
         fprintf(debug, "number of distance checks %d\n", numDistanceChecks);
 
-        print_nblist_statistics(debug, *nbl, pairSearch, rlist);
+        print_nblist_statistics(debug, *nbl, gridSet, rlist);
 
         if (haveFep)
         {
@@ -3647,13 +3645,13 @@ static void nbnxn_make_pairlist_part(const PairSearch &pairSearch,
     }
 }
 
-static void reduce_buffer_flags(const PairSearch           &pairSearch,
-                                int                         nsrc,
-                                const nbnxn_buffer_flags_t *dest)
+static void reduce_buffer_flags(gmx::ArrayRef<PairsearchWork>  searchWork,
+                                int                            nsrc,
+                                const nbnxn_buffer_flags_t    *dest)
 {
     for (int s = 0; s < nsrc; s++)
     {
-        gmx_bitmask_t * flag = pairSearch.work()[s].buffer_flags.flag;
+        gmx_bitmask_t * flag = searchWork[s].buffer_flags.flag;
 
         for (int b = 0; b < dest->nflag; b++)
         {
@@ -3929,12 +3927,14 @@ static void sort_sci(NbnxnPairlistGpu *nbl)
 static void prepareListsForDynamicPruning(gmx::ArrayRef<NbnxnPairlistCpu> lists);
 
 void
-PairlistSet::constructPairlists(PairSearch                *pairSearch,
-                                nbnxn_atomdata_t          *nbat,
-                                const t_blocka            *excl,
-                                const Nbnxm::KernelType    kernelType,
-                                const int                  minimumIlistCountForGpuBalancing,
-                                t_nrnb                    *nrnb)
+PairlistSet::constructPairlists(const Nbnxm::GridSet          &gridSet,
+                                gmx::ArrayRef<PairsearchWork>  searchWork,
+                                nbnxn_atomdata_t              *nbat,
+                                const t_blocka                *excl,
+                                const Nbnxm::KernelType        kernelType,
+                                const int                      minimumIlistCountForGpuBalancing,
+                                t_nrnb                        *nrnb,
+                                SearchCycleCounting           *searchCycleCounting)
 {
     const real         rlist    = params_.rlistOuter;
 
@@ -3966,12 +3966,12 @@ PairlistSet::constructPairlists(PairSearch                *pairSearch,
     }
     else
     {
-        nzi = pairSearch->domainSetup().zones->nizone;
+        nzi = gridSet.domainSetup().zones->nizone;
     }
 
     if (!isCpuType_ && minimumIlistCountForGpuBalancing > 0)
     {
-        get_nsubpair_target(*pairSearch, locality_, rlist, minimumIlistCountForGpuBalancing,
+        get_nsubpair_target(gridSet, locality_, rlist, minimumIlistCountForGpuBalancing,
                             &nsubpair_target, &nsubpair_tot_est);
     }
     else
@@ -3998,11 +3998,11 @@ PairlistSet::constructPairlists(PairSearch                *pairSearch,
         }
     }
 
-    const gmx_domdec_zones_t *ddZones = pairSearch->domainSetup().zones;
+    const gmx_domdec_zones_t *ddZones = gridSet.domainSetup().zones;
 
     for (int zi = 0; zi < nzi; zi++)
     {
-        const Grid &iGrid = pairSearch->gridSet().grids()[zi];
+        const Grid &iGrid = gridSet.grids()[zi];
 
         int                 zj0;
         int                 zj1;
@@ -4022,16 +4022,16 @@ PairlistSet::constructPairlists(PairSearch                *pairSearch,
         }
         for (int zj = zj0; zj < zj1; zj++)
         {
-            const Grid &jGrid = pairSearch->gridSet().grids()[zj];
+            const Grid &jGrid = gridSet.grids()[zj];
 
             if (debug)
             {
                 fprintf(debug, "ns search grid %d vs %d\n", zi, zj);
             }
 
-            pairSearch->cycleCounting_.start(PairSearch::enbsCCsearch);
+            searchCycleCounting->start(enbsCCsearch);
 
-            ci_block = get_ci_block_size(iGrid, pairSearch->domainSetup().haveDomDec, numLists);
+            ci_block = get_ci_block_size(iGrid, gridSet.domainSetup().haveMultipleDomains, numLists);
 
             /* With GPU: generate progressively smaller lists for
              * load balancing for local only or non-local with 2 zones.
@@ -4048,7 +4048,7 @@ PairlistSet::constructPairlists(PairSearch                *pairSearch,
                      */
                     if (nbat->bUseBufferFlags && ((zi == 0 && zj == 0)))
                     {
-                        init_buffer_flags(&pairSearch->work()[th].buffer_flags, nbat->numAtoms());
+                        init_buffer_flags(&searchWork[th].buffer_flags, nbat->numAtoms());
                     }
 
                     if (combineLists_ && th > 0)
@@ -4058,17 +4058,17 @@ PairlistSet::constructPairlists(PairSearch                *pairSearch,
                         clear_pairlist(&gpuLists_[th]);
                     }
 
-                    PairsearchWork *searchWork = &pairSearch->work()[th];
+                    PairsearchWork &work = searchWork[th];
 
-                    searchWork->cycleCounter.start();
+                    work.cycleCounter.start();
 
                     t_nblist *fepListPtr = (fepLists_.empty() ? nullptr : fepLists_[th]);
 
                     /* Divide the i cells equally over the pairlists */
                     if (isCpuType_)
                     {
-                        nbnxn_make_pairlist_part(*pairSearch, iGrid, jGrid,
-                                                 searchWork, nbat, *excl,
+                        nbnxn_make_pairlist_part(gridSet, iGrid, jGrid,
+                                                 &work, nbat, *excl,
                                                  rlist,
                                                  kernelType,
                                                  ci_block,
@@ -4081,8 +4081,8 @@ PairlistSet::constructPairlists(PairSearch                *pairSearch,
                     }
                     else
                     {
-                        nbnxn_make_pairlist_part(*pairSearch, iGrid, jGrid,
-                                                 searchWork, nbat, *excl,
+                        nbnxn_make_pairlist_part(gridSet, iGrid, jGrid,
+                                                 &work, nbat, *excl,
                                                  rlist,
                                                  kernelType,
                                                  ci_block,
@@ -4094,18 +4094,18 @@ PairlistSet::constructPairlists(PairSearch                *pairSearch,
                                                  fepListPtr);
                     }
 
-                    searchWork->cycleCounter.stop();
+                    work.cycleCounter.stop();
                 }
                 GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
             }
-            pairSearch->cycleCounting_.stop(PairSearch::enbsCCsearch);
+            searchCycleCounting->stop(enbsCCsearch);
 
             np_tot = 0;
             np_noq = 0;
             np_hlj = 0;
             for (int th = 0; th < numLists; th++)
             {
-                inc_nrnb(nrnb, eNR_NBNXN_DIST2, pairSearch->work()[th].ndistc);
+                inc_nrnb(nrnb, eNR_NBNXN_DIST2, searchWork[th].ndistc);
 
                 if (isCpuType_)
                 {
@@ -4137,12 +4137,12 @@ PairlistSet::constructPairlists(PairSearch                *pairSearch,
             {
                 GMX_ASSERT(!isCpuType_, "Can only combine GPU lists");
 
-                pairSearch->cycleCounting_.start(PairSearch::enbsCCcombine);
+                searchCycleCounting->start(enbsCCcombine);
 
                 combine_nblists(gmx::constArrayRefFromArray(&gpuLists_[1], numLists - 1),
                                 &gpuLists_[0]);
 
-                pairSearch->cycleCounting_.stop(PairSearch::enbsCCcombine);
+                searchCycleCounting->stop(enbsCCcombine);
             }
         }
     }
@@ -4151,7 +4151,7 @@ PairlistSet::constructPairlists(PairSearch                *pairSearch,
     {
         if (numLists > 1 && checkRebalanceSimpleLists(cpuLists_))
         {
-            rebalanceSimpleLists(cpuLists_, cpuListsWork_, pairSearch->work());
+            rebalanceSimpleLists(cpuLists_, cpuListsWork_, searchWork);
 
             /* Swap the sets of pair lists */
             cpuLists_.swap(cpuListsWork_);
@@ -4180,13 +4180,13 @@ PairlistSet::constructPairlists(PairSearch                *pairSearch,
 
     if (nbat->bUseBufferFlags)
     {
-        reduce_buffer_flags(*pairSearch, numLists, &nbat->buffer_flags);
+        reduce_buffer_flags(searchWork, numLists, &nbat->buffer_flags);
     }
 
-    if (pairSearch->gridSet().haveFep())
+    if (gridSet.haveFep())
     {
         /* Balance the free-energy lists over all the threads */
-        balance_fep_lists(fepLists_, pairSearch->work());
+        balance_fep_lists(fepLists_, searchWork);
     }
 
     if (isCpuType_)
@@ -4206,12 +4206,12 @@ PairlistSet::constructPairlists(PairSearch                *pairSearch,
         {
             for (auto &cpuList : cpuLists_)
             {
-                print_nblist_statistics(debug, cpuList, *pairSearch, rlist);
+                print_nblist_statistics(debug, cpuList, gridSet, rlist);
             }
         }
         else if (!isCpuType_ && gpuLists_.size() > 1)
         {
-            print_nblist_statistics(debug, gpuLists_[0], *pairSearch, rlist);
+            print_nblist_statistics(debug, gpuLists_[0], gridSet, rlist);
         }
     }
 
@@ -4253,7 +4253,9 @@ PairlistSets::construct(const InteractionLocality  iLocality,
                         const int64_t              step,
                         t_nrnb                    *nrnb)
 {
-    pairlistSet(iLocality).constructPairlists(pairSearch, nbat, excl, kernelType, minimumIlistCountForGpuBalancing_, nrnb);
+    pairlistSet(iLocality).constructPairlists(pairSearch->gridSet(), pairSearch->work(),
+                                              nbat, excl, kernelType, minimumIlistCountForGpuBalancing_,
+                                              nrnb, &pairSearch->cycleCounting_);
 
     if (iLocality == Nbnxm::InteractionLocality::Local)
     {
@@ -4271,7 +4273,7 @@ PairlistSets::construct(const InteractionLocality  iLocality,
         pairSearch->cycleCounting_.searchCount_++;
     }
     if (pairSearch->cycleCounting_.recordCycles_ &&
-        (!pairSearch->domainSetup().haveDomDec || iLocality == InteractionLocality::NonLocal) &&
+        (!pairSearch->gridSet().domainSetup().haveMultipleDomains || iLocality == InteractionLocality::NonLocal) &&
         pairSearch->cycleCounting_.searchCount_ % 100 == 0)
     {
         pairSearch->cycleCounting_.printCycles(stderr, pairSearch->work());