Merge branch 'release-2016'
src/gromacs/mdlib/nbnxn_cuda/nbnxn_cuda_kernel.cuh
index c4ec038d2c4504a46c70d8d66ce3c1b475bf1377..c0ef88571064e342206c18a5594ece1d153b73fa 100644 (file)
  * shuffle-based reduction, hence CC >= 3.0.
  *
  *
- * NOTEs / TODO on Volta / CUDA 9 support extensions:
- * - the current way of computing active mask using ballot_sync() should be
- *   reconsidered: we can compute all masks with bitwise ops iso ballot and
- *   secondly, all conditionals are warp-uniform, so the sync is not needed;
- * - reconsider the use of __syncwarp(): its only role is currently to prevent
+ * NOTEs on Volta / CUDA 9 extensions:
+ *
+ * - While active thread masks are required for the warp collectives we use
+ *   (any and shfl), the kernel is designed such that all conditions
+ *   (other than the innermost distance check), including loop trip counts,
+ *   are warp-synchronous. Therefore, we do not need ballot to compute the
+ *   active masks, as these are all full-warp masks.
+ *
+ * - TODO: reconsider the use of __syncwarp(): its only role is currently to prevent
  *   WAR hazard due to the cj preload; we should try to replace it with direct
  *   loads (which may be faster given the improved L1 on Volta).
  */
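As a side note to the comment above, the sketch below (hypothetical names, not part of this file) illustrates why __ballot_sync() is redundant when a branch is warp-uniform: ballot over the gating condition can only return the full-warp mask, so the compile-time constant 0xffffffffU can be passed to the warp collectives directly, which is what the kernel now does everywhere except the innermost distance check.

#include <cuda_runtime.h>

__global__ void fullWarpMaskSketch(const int *in, int *out, int n)
{
    const unsigned int c_sketchFullWarpMask = 0xffffffffU; /* hypothetical stand-in for c_fullWarpMask */
    const int          warp                 = threadIdx.x >> 5;

    /* Warp-uniform condition: all 32 lanes of a warp evaluate it identically,
       so either the whole warp enters the branch or none of it does. */
    if ((warp + 1) * 32 <= n)
    {
        const int value = in[threadIdx.x];

        /* Computing the active mask with ballot over the gating condition
           works, but here it can only return the full-warp mask... */
        const unsigned int ballotMask =
            __ballot_sync(c_sketchFullWarpMask, (warp + 1) * 32 <= n);

        /* ...so the collective can take the compile-time full-warp mask directly. */
        const int anyPositive = __any_sync(c_sketchFullWarpMask, value > 0);

        out[threadIdx.x] = (ballotMask == c_sketchFullWarpMask) ? anyPositive : -1;
    }
}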
@@ -358,8 +362,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
     const int nonSelfInteraction  = !(nb_sci.shift == CENTRAL & tidxj <= tidxi);
 #endif
 
-    int          j4LoopStart      = cij4_start + tidxz;
-    unsigned int j4LoopThreadMask = gmx_ballot_sync(c_fullWarpMask, j4LoopStart < cij4_end);
+    int          j4LoopStart = cij4_start + tidxz;
     /* loop over the j clusters seen by any of the atoms in the current super-cluster */
     for (j4 = j4LoopStart; j4 < cij4_end; j4 += NTHREAD_Z)
     {
@@ -367,9 +370,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
         imask       = pl_cj4[j4].imei[widx].imask;
         wexcl       = excl[wexcl_idx].pair[(tidx) & (warp_size - 1)];
 
-        unsigned int imaskSkipConditionThreadMask = j4LoopThreadMask;
 #ifndef PRUNE_NBL
-        imaskSkipConditionThreadMask = gmx_ballot_sync(j4LoopThreadMask, imask);
         if (imask)
 #endif
         {
@@ -378,7 +379,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
             {
                 cjs[tidxi + tidxj * c_nbnxnGpuJgroupSize/c_splitClSize] = pl_cj4[j4].cj[tidxi];
             }
-            gmx_syncwarp(imaskSkipConditionThreadMask);
+            gmx_syncwarp(c_fullWarpMask);
 
             /* Unrolling this loop
                - with pruning leads to register spilling;
@@ -386,8 +387,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
                Tested with up to nvcc 7.5 */
             for (jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
             {
-                const unsigned int jmSkipCondition           = imask & (superClInteractionMask << (jm * c_numClPerSupercl));
-                const unsigned int jmSkipConditionThreadMask = gmx_ballot_sync(imaskSkipConditionThreadMask, jmSkipCondition);
+                const unsigned int jmSkipCondition = imask & (superClInteractionMask << (jm * c_numClPerSupercl));
                 if (jmSkipCondition)
                 {
                     mask_ji = (1U << (jm * c_numClPerSupercl));
@@ -412,8 +412,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
 #endif
                     for (i = 0; i < c_numClPerSupercl; i++)
                     {
-                        const unsigned int iInnerSkipCondition           = imask & mask_ji;
-                        const unsigned int iInnerSkipConditionThreadMask = gmx_ballot_sync(jmSkipConditionThreadMask, iInnerSkipCondition);
+                        const unsigned int iInnerSkipCondition = imask & mask_ji;
                         if (iInnerSkipCondition)
                         {
                             ci      = sci * c_numClPerSupercl + i; /* i cluster index */
@@ -430,7 +429,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
                             /* If _none_ of the atom pairs are within the cutoff range,
                                the bit corresponding to the current
                                cluster-pair in imask gets set to 0. */
-                            if (!gmx_any_sync(iInnerSkipConditionThreadMask, r2 < rlist_sq))
+                            if (!gmx_any_sync(c_fullWarpMask, r2 < rlist_sq))
                             {
                                 imask &= ~mask_ji;
                             }
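The pruning step above reduces to the following idiom; this is only a stripped-down sketch with hypothetical names (pruneMaskBit, listCutoffSq), not the routine used in this file. Because __any_sync() is a warp collective, every lane learns whether at least one pair in the cluster pair is inside the list cutoff, so all lanes clear the same imask bit together.

#include <cuda_runtime.h>

__device__ unsigned int pruneMaskBit(unsigned int imask,
                                     unsigned int maskBit,
                                     float        pairDistSq,
                                     float        listCutoffSq)
{
    const unsigned int c_sketchFullWarpMask = 0xffffffffU;

    /* If no lane in the warp has a pair within the list cutoff, the whole
       cluster pair can be dropped from the pruned pair list. */
    if (!__any_sync(c_sketchFullWarpMask, pairDistSq < listCutoffSq))
    {
        imask &= ~maskBit;
    }
    return imask;
}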
@@ -591,7 +590,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
                     }
 
                     /* reduce j forces */
-                    reduce_force_j_warp_shfl(fcj_buf, f, tidxi, aj, jmSkipConditionThreadMask);
+                    reduce_force_j_warp_shfl(fcj_buf, f, tidxi, aj, c_fullWarpMask);
                 }
             }
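For reference, a warp-shuffle force reduction over the full warp looks roughly like the sketch below; this is a generic illustration with hypothetical names, not the actual reduce_force_j_warp_shfl() implementation, which differs in data layout and indexing.

#include <cuda_runtime.h>

__device__ void reduceForceWarpShflSketch(float3 f, float *fOut, int atomIndex)
{
    const unsigned int c_sketchFullWarpMask = 0xffffffffU;

    /* Tree reduction within the warp; the full-warp mask is valid because all
       32 lanes reach this point together. */
    for (int offset = 16; offset > 0; offset >>= 1)
    {
        f.x += __shfl_down_sync(c_sketchFullWarpMask, f.x, offset);
        f.y += __shfl_down_sync(c_sketchFullWarpMask, f.y, offset);
        f.z += __shfl_down_sync(c_sketchFullWarpMask, f.z, offset);
    }

    /* Lane 0 now holds the warp-wide sum and accumulates it into global memory. */
    if ((threadIdx.x & 31) == 0)
    {
        atomicAdd(&fOut[3 * atomIndex],     f.x);
        atomicAdd(&fOut[3 * atomIndex + 1], f.y);
        atomicAdd(&fOut[3 * atomIndex + 2], f.z);
    }
}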
 #ifdef PRUNE_NBL
@@ -601,9 +600,7 @@ __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
 #endif
         }
         // avoid shared memory WAR hazards between loop iterations
-        gmx_syncwarp(j4LoopThreadMask);
-        // update thread mask for next loop iteration
-        j4LoopThreadMask = gmx_ballot_sync(j4LoopThreadMask, (j4 + NTHREAD_Z) < cij4_end);
+        gmx_syncwarp(c_fullWarpMask);
     }
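The preload-and-__syncwarp() pattern that the loop above relies on can be reduced to the minimal sketch below (hypothetical names, and it assumes a single warp per thread block). The shared buffer is reused every iteration, so the warp synchronizes after the preload so that all lanes see the new values, and again at the end of the iteration so that no lane starts the next preload while another lane is still reading, which is the WAR hazard the comment above refers to.

#include <cuda_runtime.h>

__global__ void preloadSyncwarpSketch(const int *indices, const float *x,
                                      float *acc, int numChunks)
{
    const unsigned int c_sketchFullWarpMask = 0xffffffffU;
    __shared__ int     sharedIdx[32]; /* reused by every loop iteration */

    const int lane = threadIdx.x & 31;
    float     sum  = 0.0f;

    for (int chunk = 0; chunk < numChunks; chunk++)
    {
        /* Each lane preloads one index of the current chunk. */
        sharedIdx[lane] = indices[chunk * 32 + lane];
        __syncwarp(c_sketchFullWarpMask); /* make the preload visible to all lanes */

        /* Every lane reads all preloaded indices, including other lanes' entries. */
        for (int j = 0; j < 32; j++)
        {
            sum += x[sharedIdx[j]];
        }

        /* Avoid the WAR hazard: no lane may start the next preload while another
           lane could still be reading sharedIdx from this iteration. */
        __syncwarp(c_sketchFullWarpMask);
    }

    acc[threadIdx.x] = sum;
}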
 
     /* skip central shifts when summing shift forces */