Improved the intra-GPU load balancing
[alexxy/gromacs.git] / src / gromacs / mdlib / nbnxn_search.c
index e28c0c0e2d486d0a769e8ab469cc8fcda98e8b64..7e278a94ee814f077610956989ae9eb36e5fc703 100644 (file)
@@ -2675,11 +2675,6 @@ static void print_nblist_statistics_supersub(FILE *fp, const nbnxn_pairlist_t *n
             nbl->nci_tot/(double)grid->nsubc_tot*grid->na_c,
             nbl->nci_tot/(double)grid->nsubc_tot*grid->na_c/(0.5*4.0/3.0*M_PI*rl*rl*rl*grid->nsubc_tot*grid->na_c/(grid->size[XX]*grid->size[YY]*grid->size[ZZ])));
 
-    fprintf(fp, "nbl average j super cell list length %.1f\n",
-            0.25*nbl->ncj4/(double)nbl->nsci);
-    fprintf(fp, "nbl average i sub cell list length %.1f\n",
-            nbl->nci_tot/((double)nbl->ncj4));
-
     sum_nsp  = 0;
     sum_nsp2 = 0;
     nsp_max  = 0;
@@ -3957,10 +3952,11 @@ static void close_ci_entry_simple(nbnxn_pairlist_t *nbl)
  * both on nthread and our own thread index.
  */
 static void split_sci_entry(nbnxn_pairlist_t *nbl,
-                            int nsp_max_av, gmx_bool progBal, int nc_bal,
+                            int nsp_target_av,
+                            gmx_bool progBal, int nsp_tot_est,
                             int thread, int nthread)
 {
-    int nsci_est;
+    int nsp_est;
     int nsp_max;
     int cj4_start, cj4_end, j4len, cj4;
     int sci;
@@ -3972,19 +3968,27 @@ static void split_sci_entry(nbnxn_pairlist_t *nbl,
         /* Estimate the total numbers of ci's of the nblist combined
          * over all threads using the target number of ci's.
          */
-        nsci_est = nc_bal*thread/nthread + nbl->nsci;
+        nsp_est = (nsp_tot_est*thread)/nthread + nbl->nci_tot;
 
         /* The first ci blocks should be larger, to avoid overhead.
          * The last ci blocks should be smaller, to improve load balancing.
+         * The factor 3/2 makes the first block 3/2 times the target average
+         * and ensures that the total number of blocks end up equal to
+         * that with of equally sized blocks of size nsp_target_av.
          */
-        nsp_max = max(1,
-                      nsp_max_av*nc_bal*3/(2*(nsci_est - 1 + nc_bal)));
+        nsp_max = nsp_target_av*nsp_tot_est*3/(2*(nsp_est + nsp_tot_est));
     }
     else
     {
-        nsp_max = nsp_max_av;
+        nsp_max = nsp_target_av;
     }
 
+    /* Since nsp_max is a maximum/cut-off (this avoids high outliers,
+     * which lead to load imbalance), not an average, we add half the
+     * number of pairs in a cj4 block to get the average about right.
+     */
+    nsp_max += GPU_NSUBCELL*NBNXN_GPU_JGROUP_SIZE/2;
+
     cj4_start = nbl->sci[nbl->nsci-1].cj4_ind_start;
     cj4_end   = nbl->sci[nbl->nsci-1].cj4_ind_end;
     j4len     = cj4_end - cj4_start;
@@ -4049,7 +4053,7 @@ static void split_sci_entry(nbnxn_pairlist_t *nbl,
 /* Clost this super/sub list i entry */
 static void close_ci_entry_supersub(nbnxn_pairlist_t *nbl,
                                     int nsp_max_av,
-                                    gmx_bool progBal, int nc_bal,
+                                    gmx_bool progBal, int nsp_tot_est,
                                     int thread, int nthread)
 {
     int j4len, tlen;
@@ -4072,7 +4076,8 @@ static void close_ci_entry_supersub(nbnxn_pairlist_t *nbl,
         if (nsp_max_av > 0)
         {
             /* Measure the size of the new entry and potentially split it */
-            split_sci_entry(nbl, nsp_max_av, progBal, nc_bal, thread, nthread);
+            split_sci_entry(nbl, nsp_max_av, progBal, nsp_tot_est,
+                            thread, nthread);
         }
     }
 }
@@ -4344,11 +4349,17 @@ static real nonlocal_vol2(const gmx_domdec_zones_t *zones, rvec ls, real r)
 }
 
 /* Estimates the average size of a full j-list for super/sub setup */
-static int get_nsubpair_max(const nbnxn_search_t nbs,
-                            int                  iloc,
-                            real                 rlist,
-                            int                  min_ci_balanced)
+static void get_nsubpair_target(const nbnxn_search_t  nbs,
+                                int                   iloc,
+                                real                  rlist,
+                                int                   min_ci_balanced,
+                                int                  *nsubpair_target,
+                                int                  *nsubpair_tot_est)
 {
+    /* The target value of 36 seems to be the optimum for Kepler.
+     * Maxwell is less sensitive to the exact value.
+     */
+    const int           nsubpair_target_min = 36;
     const nbnxn_grid_t *grid;
     rvec                ls;
     real                xy_diag2, r_eff_sup, vol_est, nsp_est, nsp_est_nl;
@@ -4358,8 +4369,11 @@ static int get_nsubpair_max(const nbnxn_search_t nbs,
 
     if (min_ci_balanced <= 0 || grid->nc >= min_ci_balanced || grid->nc == 0)
     {
-        /* We don't need to worry */
-        return -1;
+        /* We don't need to balance the list sizes */
+        *nsubpair_target  = 0;
+        *nsubpair_tot_est = 0;
+
+        return;
     }
 
     ls[XX] = (grid->c1[XX] - grid->c0[XX])/(grid->ncx*GPU_NSUBCELL_X);
@@ -4410,22 +4424,19 @@ static int get_nsubpair_max(const nbnxn_search_t nbs,
         nsp_est = nsp_est_nl;
     }
 
-    /* Thus the (average) maximum j-list size should be as follows */
-    nsubpair_max = max(1, (int)(nsp_est/min_ci_balanced+0.5));
-
-    /* Since the target value is a maximum (this avoids high outliers,
-     * which lead to load imbalance), not average, we add half the
-     * number of pairs in a cj4 block to get the average about right.
+    /* Thus the (average) maximum j-list size should be as follows.
+     * Since there is overhead, we shouldn't make the lists too small
+     * (and we can't chop up j-groups) so we use a minimum target size of 36.
      */
-    nsubpair_max += GPU_NSUBCELL*NBNXN_GPU_JGROUP_SIZE/2;
+    *nsubpair_target  = max(nsubpair_target_min,
+                            (int)(nsp_est/min_ci_balanced + 0.5));
+    *nsubpair_tot_est = (int)nsp_est;
 
     if (debug)
     {
-        fprintf(debug, "nbl nsp estimate %.1f, nsubpair_max %d\n",
-                nsp_est, nsubpair_max);
+        fprintf(debug, "nbl nsp estimate %.1f, nsubpair_target %d\n",
+                nsp_est, *nsubpair_target);
     }
-
-    return nsubpair_max;
 }
 
 /* Debug list print function */
@@ -4828,7 +4839,7 @@ static void nbnxn_make_pairlist_part(const nbnxn_search_t nbs,
                                      gmx_bool bFBufferFlag,
                                      int nsubpair_max,
                                      gmx_bool progBal,
-                                     int min_ci_balanced,
+                                     int nsubpair_tot_est,
                                      int th, int nth,
                                      nbnxn_pairlist_t *nbl,
                                      t_nblist *nbl_fep)
@@ -5422,7 +5433,7 @@ static void nbnxn_make_pairlist_part(const nbnxn_search_t nbs,
                     {
                         close_ci_entry_supersub(nbl,
                                                 nsubpair_max,
-                                                progBal, min_ci_balanced,
+                                                progBal, nsubpair_tot_est,
                                                 th, nth);
                     }
                 }
@@ -5616,7 +5627,7 @@ void nbnxn_make_pairlist(const nbnxn_search_t  nbs,
     nbnxn_grid_t      *gridi, *gridj;
     gmx_bool           bGPUCPU;
     int                nzi, zi, zj0, zj1, zj;
-    int                nsubpair_max;
+    int                nsubpair_target, nsubpair_tot_est;
     int                th;
     int                nnbl;
     nbnxn_pairlist_t **nbl;
@@ -5686,11 +5697,13 @@ void nbnxn_make_pairlist(const nbnxn_search_t  nbs,
 
     if (!nbl_list->bSimple && min_ci_balanced > 0)
     {
-        nsubpair_max = get_nsubpair_max(nbs, iloc, rlist, min_ci_balanced);
+        get_nsubpair_target(nbs, iloc, rlist, min_ci_balanced,
+                            &nsubpair_target, &nsubpair_tot_est);
     }
     else
     {
-        nsubpair_max = 0;
+        nsubpair_target  = 0;
+        nsubpair_tot_est = 0;
     }
 
     /* Clear all pair-lists */
@@ -5767,8 +5780,8 @@ void nbnxn_make_pairlist(const nbnxn_search_t  nbs,
                                          nb_kernel_type,
                                          ci_block,
                                          nbat->bUseBufferFlags,
-                                         nsubpair_max,
-                                         progBal, min_ci_balanced,
+                                         nsubpair_target,
+                                         progBal, nsubpair_tot_est,
                                          th, nnbl,
                                          nbl[th],
                                          nbl_list->nbl_fep[th]);