NVIDIA Volta performance tweaks
author     Szilárd Páll <pall.szilard@gmail.com>
Mon, 4 Sep 2017 15:26:59 +0000 (17:26 +0200)
committer  Mark Abraham <mark.j.abraham@gmail.com>
Mon, 11 Sep 2017 15:09:24 +0000 (17:09 +0200)
Removed ballot syncs and replaced all computed masks with the full warp
mask (all branches in question are warp-synchronous).
This improves performance by 7-12%.

Change-Id: I769d6d8f0d171eb528d30868d567624d5e246dbf
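
To illustrate the pattern this change applies, here is a minimal standalone CUDA sketch (not GROMACS code; only c_fullWarpMask mirrors the constant used in the kernel): when the enclosing branch is taken identically by all 32 lanes of a warp, the member mask passed to a warp collective can simply be the full warp mask, and a preceding ballot becomes redundant.

    static constexpr unsigned int c_fullWarpMask = 0xffffffffU; /* all 32 lanes set */

    /* Toy kernel; blockDim.x is assumed to be a multiple of 32. */
    __global__ void warpUniformMaskExample(const int* warpFlags, const float* r2,
                                           float rlistSq, float* out)
    {
        const int warpId = threadIdx.x / warpSize;

        /* The condition depends only on per-warp data, so either every lane of
           the warp enters the branch or none does: it is warp-synchronous. */
        if (warpFlags[warpId])
        {
            /* Old pattern: compute the active-lane set explicitly.
                 unsigned int mask = __ballot_sync(c_fullWarpMask, warpFlags[warpId] != 0);
               The predicate is identical on all 32 lanes, so mask == c_fullWarpMask
               and the ballot only adds latency. */

            /* New pattern: pass the full warp mask straight to the collective. */
            if (__any_sync(c_fullWarpMask, r2[threadIdx.x] < rlistSq))
            {
                out[threadIdx.x] = r2[threadIdx.x];
            }
        }
    }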

src/gromacs/mdlib/nbnxn_cuda/nbnxn_cuda_kernel.cuh

index a9411e42ec8aef631f8e62afe8aa2a705da51c1a..00bc3b2b926f16ada3d90bd4a3906186036179f3 100644
  * shuffle-based reduction, hence CC >= 3.0.
  *
  *
- * NOTEs / TODO on Volta / CUDA 9 support extensions:
- * - the current way of computing active mask using ballot_sync() should be
- *   reconsidered: we can compute all masks with bitwise ops iso ballot and
- *   secondly, all conditionals are warp-uniform, so the sync is not needed;
- * - reconsider the use of __syncwarp(): its only role is currently to prevent
+ * NOTEs on Volta / CUDA 9 extensions:
+ *
+ * - While active thread masks are required for the warp collectives
+ *   (we use any and shfl), the kernel is designed such that all conditions
+ *   (other than the inner-most distance check) including loop trip counts
+ *   are warp-synchronous. Therefore, we don't need ballot to compute the
+ *   active masks as these are all full-warp masks.
+ *
+ * - TODO: reconsider the use of __syncwarp(): its only role is currently to prevent
  *   WAR hazard due to the cj preload; we should try to replace it with direct
  *   loads (which may be faster given the improved L1 on Volta).
  */
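
As a rough illustration of the point about warp collectives: when the control flow around a shuffle-based reduction is warp-synchronous, every lane reaches the shuffle, so the full warp mask describes exactly the participating lanes. A standalone sketch of such a reduction (illustrative only, not the kernel's reduce_force_*_warp_shfl helpers):

    static constexpr unsigned int c_fullWarpMask = 0xffffffffU;

    /* Sum one float per lane across the warp; lane 0 accumulates the result.
       Only valid when all 32 lanes of the warp call this function. */
    __device__ void warpSumAndAccumulate(float v, float* dest)
    {
        for (int offset = warpSize / 2; offset > 0; offset >>= 1)
        {
            /* All lanes participate, so the full warp mask is the correct
               member mask for the shuffle. */
            v += __shfl_down_sync(c_fullWarpMask, v, offset);
        }
        if ((threadIdx.x & (warpSize - 1)) == 0)
        {
            atomicAdd(dest, v); /* lane 0 now holds the warp-wide sum */
        }
    }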
@@ -351,8 +355,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
 
 #endif                                  /* CALC_ENERGIES */
 
-    int          j4LoopStart      = cij4_start + tidxz;
-    unsigned int j4LoopThreadMask = gmx_ballot_sync(c_fullWarpMask, j4LoopStart < cij4_end);
+    int          j4LoopStart = cij4_start + tidxz;
     /* loop over the j clusters = seen by any of the atoms in the current super-cluster */
     for (j4 = j4LoopStart; j4 < cij4_end; j4 += NTHREAD_Z)
     {
@@ -360,9 +363,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
         imask       = pl_cj4[j4].imei[widx].imask;
         wexcl       = excl[wexcl_idx].pair[(tidx) & (warp_size - 1)];
 
-        unsigned int imaskSkipConditionThreadMask = j4LoopThreadMask;
 #ifndef PRUNE_NBL
-        imaskSkipConditionThreadMask = gmx_ballot_sync(j4LoopThreadMask, imask);
         if (imask)
 #endif
         {
@@ -371,7 +372,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
             {
                 cjs[tidxi + tidxj * c_nbnxnGpuJgroupSize/c_splitClSize] = pl_cj4[j4].cj[tidxi];
             }
-            gmx_syncwarp(imaskSkipConditionThreadMask);
+            gmx_syncwarp(c_fullWarpMask);
 
             /* Unrolling this loop
                - with pruning leads to register spilling;
@@ -379,8 +380,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
                Tested with up to nvcc 7.5 */
             for (jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
             {
-                const unsigned int jmSkipCondition           = imask & (superClInteractionMask << (jm * c_numClPerSupercl));
-                const unsigned int jmSkipConditionThreadMask = gmx_ballot_sync(imaskSkipConditionThreadMask, jmSkipCondition);
+                const unsigned int jmSkipCondition = imask & (superClInteractionMask << (jm * c_numClPerSupercl));
                 if (jmSkipCondition)
                 {
                     mask_ji = (1U << (jm * c_numClPerSupercl));
@@ -405,8 +405,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
 #endif
                     for (i = 0; i < c_numClPerSupercl; i++)
                     {
-                        const unsigned int iInnerSkipCondition           = imask & mask_ji;
-                        const unsigned int iInnerSkipConditionThreadMask = gmx_ballot_sync(jmSkipConditionThreadMask, iInnerSkipCondition);
+                        const unsigned int iInnerSkipCondition = imask & mask_ji;
                         if (iInnerSkipCondition)
                         {
                             ci      = sci * c_numClPerSupercl + i; /* i cluster index */
@@ -423,7 +422,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
                             /* If _none_ of the atoms pairs are in cutoff range,
                                the bit corresponding to the current
                                cluster-pair in imask gets set to 0. */
-                            if (!gmx_any_sync(iInnerSkipConditionThreadMask, r2 < rlist_sq))
+                            if (!gmx_any_sync(c_fullWarpMask, r2 < rlist_sq))
                             {
                                 imask &= ~mask_ji;
                             }
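
For reference, the pruning step above only needs to know whether any lane of the warp found an atom pair within the list cutoff; since the surrounding branches are warp-synchronous, all 32 lanes evaluate the any-collective and the full warp mask is the correct member mask. A simplified standalone sketch of the idea (hypothetical helper, not the kernel's code):

    static constexpr unsigned int c_fullWarpMask = 0xffffffffU;

    /* Clear the interaction bit for the current cluster pair if no lane of the
       warp has a pair distance within the list cutoff; all 32 lanes must call it. */
    __device__ unsigned int pruneClusterPairBit(unsigned int imask, unsigned int mask_ji,
                                                float r2, float rlist_sq)
    {
        if (!__any_sync(c_fullWarpMask, r2 < rlist_sq))
        {
            imask &= ~mask_ji; /* warp-uniform outcome: every lane clears the same bit */
        }
        return imask;
    }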
@@ -586,7 +585,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
                     }
 
                     /* reduce j forces */
-                    reduce_force_j_warp_shfl(fcj_buf, f, tidxi, aj, jmSkipConditionThreadMask);
+                    reduce_force_j_warp_shfl(fcj_buf, f, tidxi, aj, c_fullWarpMask);
                 }
             }
 #ifdef PRUNE_NBL
@@ -596,9 +595,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
 #endif
         }
         // avoid shared memory WAR hazards between loop iterations
-        gmx_syncwarp(j4LoopThreadMask);
-        // update thread mask for next loop iteration
-        j4LoopThreadMask = gmx_ballot_sync(j4LoopThreadMask, (j4 + NTHREAD_Z) < cij4_end);
+        gmx_syncwarp(c_fullWarpMask);
     }
 
     /* skip central shifts when summing shift forces */
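
The gmx_syncwarp(c_fullWarpMask) at the end of the iteration remains necessary because, on Volta, lanes of a warp no longer execute in lockstep: without the barrier, a lane could start the next iteration's cj preload and overwrite shared-memory entries that another lane is still reading, i.e. a write-after-read hazard. A simplified standalone sketch of that pattern (hypothetical names, one warp per block assumed):

    static constexpr unsigned int c_fullWarpMask = 0xffffffffU;

    __global__ void preloadLoopExample(const int* cjPacked, int nIter, int* out)
    {
        __shared__ int cjs[32]; /* staging buffer; one warp per block in this toy */

        for (int iter = 0; iter < nIter; iter++)
        {
            /* Preload this iteration's j-cluster indices into shared memory. */
            cjs[threadIdx.x] = cjPacked[iter * 32 + threadIdx.x];
            __syncwarp(c_fullWarpMask); /* make the preload visible to all lanes */

            /* ... each lane reads cjs[] entries written by other lanes ... */
            out[iter * 32 + threadIdx.x] = cjs[(threadIdx.x + 1) & 31];

            /* Keep the next iteration's preload from overwriting entries that
               other lanes may still be reading: the WAR hazard noted above. */
            __syncwarp(c_fullWarpMask);
        }
    }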