Use gmx::Range for iZones in pair search
authorBerk Hess <hess@kth.se>
Wed, 25 Sep 2019 15:15:42 +0000 (17:15 +0200)
committerArtem Zhmurov <zhmurov@gmail.com>
Thu, 3 Oct 2019 11:32:00 +0000 (13:32 +0200)
Change-Id: Ifed1ac3ed2fc0b02680a33d4e44620f82248dda9

src/gromacs/nbnxm/pairlist.cpp

index ff60f433981dbad9d856eeffc1b3709ec7fe0c8a..8b55a3a9ddd72fc2a829f260f51d17645999e06f 100644 (file)
@@ -3963,6 +3963,61 @@ static void sort_sci(NbnxnPairlistGpu *nbl)
     std::swap(nbl->sci, work.sci_sort);
 }
 
+/* Returns the i-zone range for pairlist construction for the give locality */
+static Range<int>
+getIZoneRange(const Nbnxm::GridSet::DomainSetup &domainSetup,
+              const Nbnxm::InteractionLocality   locality)
+{
+    if (domainSetup.doTestParticleInsertion)
+    {
+        /* With TPI we do grid 1, the inserted molecule, versus grid 0, the rest */
+        return {
+                   1, 2
+        };
+    }
+    else if (locality == InteractionLocality::Local)
+    {
+        /* Local: only zone (grid) 0 vs 0 */
+        return {
+                   0, 1
+        };
+    }
+    else
+    {
+        /* Non-local: we need all i-zones */
+        return {
+                   0, int(domainSetup.zones->iZones.size())
+        };
+    }
+}
+
+/* Returns the j-zone range for pairlist construction for the give locality and i-zone */
+static Range<int>
+getJZoneRange(const gmx_domdec_zones_t          &ddZones,
+              const Nbnxm::InteractionLocality   locality,
+              const int                          iZone)
+{
+    if (locality == InteractionLocality::Local)
+    {
+        /* Local: zone 0 vs 0 or with TPI 1 vs 0 */
+        return {
+                   0, 1
+        };
+    }
+    else if (iZone == 0)
+    {
+        /* Non-local: we need to avoid the local (zone 0 vs 0) interactions */
+        return {
+                   1, *ddZones.iZones[iZone].jZoneRange.end()
+        };
+    }
+    else
+    {
+        /* Non-local with non-local i-zone: use all j-zones */
+        return ddZones.iZones[iZone].jZoneRange;
+    }
+}
+
 //! Prepares CPU lists produced by the search for dynamic pruning
 static void prepareListsForDynamicPruning(gmx::ArrayRef<NbnxnPairlistCpu> lists);
 
@@ -3997,17 +4052,6 @@ PairlistSet::constructPairlists(const Nbnxm::GridSet          &gridSet,
         init_buffer_flags(&nbat->buffer_flags, nbat->numAtoms());
     }
 
-    int nzi;
-    if (locality_ == InteractionLocality::Local)
-    {
-        /* Only zone (grid) 0 vs 0 */
-        nzi = 1;
-    }
-    else
-    {
-        nzi = gridSet.domainSetup().zones->iZones.size();
-    }
-
     if (!isCpuType_ && minimumIlistCountForGpuBalancing > 0)
     {
         get_nsubpair_target(gridSet, locality_, rlist, minimumIlistCountForGpuBalancing,
@@ -4037,34 +4081,23 @@ PairlistSet::constructPairlists(const Nbnxm::GridSet          &gridSet,
         }
     }
 
-    const gmx_domdec_zones_t *ddZones = gridSet.domainSetup().zones;
+    const gmx_domdec_zones_t &ddZones = *gridSet.domainSetup().zones;
+
+    const auto iZoneRange = getIZoneRange(gridSet.domainSetup(), locality_);
 
-    for (int zi = 0; zi < nzi; zi++)
+    for (const int iZone : iZoneRange)
     {
-        /* With TPI we do grid 1, the inserted molecule, versus grid 0, the rest */
-        if (gridSet.domainSetup().doTestParticleInsertion)
-        {
-            zi = 1;
-        }
-        const Grid &iGrid = gridSet.grids()[zi];
+        const Grid &iGrid = gridSet.grids()[iZone];
+
+        const auto jZoneRange = getJZoneRange(ddZones, locality_, iZone);
 
-        Range<int> jZoneRange;
-        if (locality_ == InteractionLocality::Local)
-        {
-            jZoneRange = Range<int>(0, 1);
-        }
-        else
-        {
-            jZoneRange = Range<int>(ddZones->iZones[zi].jZoneRange.begin() + (zi == 0 ? 1 : 0),
-                                    ddZones->iZones[zi].jZoneRange.end());
-        }
         for (int jZone : jZoneRange)
         {
             const Grid &jGrid = gridSet.grids()[jZone];
 
             if (debug)
             {
-                fprintf(debug, "ns search grid %d vs %d\n", zi, jZone);
+                fprintf(debug, "ns search grid %d vs %d\n", iZone, jZone);
             }
 
             searchCycleCounting->start(enbsCCsearch);
@@ -4074,7 +4107,7 @@ PairlistSet::constructPairlists(const Nbnxm::GridSet          &gridSet,
             /* With GPU: generate progressively smaller lists for
              * load balancing for local only or non-local with 2 zones.
              */
-            progBal = (locality_ == InteractionLocality::Local || ddZones->n <= 2);
+            progBal = (locality_ == InteractionLocality::Local || ddZones.n <= 2);
 
 #pragma omp parallel for num_threads(numLists) schedule(static)
             for (int th = 0; th < numLists; th++)
@@ -4084,7 +4117,7 @@ PairlistSet::constructPairlists(const Nbnxm::GridSet          &gridSet,
                     /* Re-init the thread-local work flag data before making
                      * the first list (not an elegant conditional).
                      */
-                    if (nbat->bUseBufferFlags && (zi == 0 && jZone == 0))
+                    if (nbat->bUseBufferFlags && (iZone == 0 && jZone == 0))
                     {
                         init_buffer_flags(&searchWork[th].buffer_flags, nbat->numAtoms());
                     }