Allow disabling cj prefetch in the CUDA nbnxm kernels
authorSzilárd Páll <pall.szilard@gmail.com>
Wed, 29 Sep 2021 07:41:05 +0000 (07:41 +0000)
committerBerk Hess <hess@kth.se>
Wed, 29 Sep 2021 07:41:05 +0000 (07:41 +0000)
src/gromacs/nbnxm/cuda/nbnxm_cuda_kernel.cuh
src/gromacs/nbnxm/cuda/nbnxm_cuda_kernel_pruneonly.cuh
src/gromacs/nbnxm/cuda/nbnxm_cuda_kernel_utils.cuh

index 5aaf3b9ead96f5691677ab14031d77694adfbc5d..35036f391531709b83c2579e4987957bbece0e02 100644 (file)
  *   are warp-synchronous. Therefore, we don't need ballot to compute the
  *   active masks as these are all full-warp masks.
  *
- * - TODO: reconsider the use of __syncwarp(): its only role is currently to prevent
- *   WAR hazard due to the cj preload; we should try to replace it with direct
- *   loads (which may be faster given the improved L1 on Volta).
  */
 
 /* Kernel launch bounds for different compute capabilities. The value of NTHREAD_Z
@@ -252,6 +249,15 @@ __launch_bounds__(THREADS_PER_BLOCK)
     /*! i-cluster interaction mask for a super-cluster with all c_nbnxnGpuNumClusterPerSupercluster=8 bits set */
     const unsigned superClInteractionMask = ((1U << c_nbnxnGpuNumClusterPerSupercluster) - 1U);
 
+    // cj preload is off in the following cases:
+    // - sm_70 (V100), sm_80 (A100), sm_86 (GA02)
+    // - for future arch (> 8.6 at the time of writing) we assume it is better to keep it off
+    // cj preload is left on for:
+    // - sm_75: improvements +/- very small
+    // - sm_61: tested and slower without preload
+    // - sm_6x and earlier not tested to
+    constexpr bool c_preloadCj = (GMX_PTX_ARCH < 700 || GMX_PTX_ARCH == 750);
+
     /*********************************************************************
      * Set up shared memory pointers.
      * sm_nextSlotPtr should always be updated to point to the "next slot",
@@ -269,9 +275,12 @@ __launch_bounds__(THREADS_PER_BLOCK)
 
     /* shmem buffer for cj, for each warp separately */
     int* cjs = reinterpret_cast<int*>(sm_nextSlotPtr);
-    /* the cjs buffer's use expects a base pointer offset for pairs of warps in the j-concurrent execution */
-    cjs += tidxz * c_nbnxnGpuClusterpairSplit * c_nbnxnGpuJgroupSize;
-    sm_nextSlotPtr += (NTHREAD_Z * c_nbnxnGpuClusterpairSplit * c_nbnxnGpuJgroupSize * sizeof(*cjs));
+    if (c_preloadCj)
+    {
+        /* the cjs buffer's use expects a base pointer offset for pairs of warps in the j-concurrent execution */
+        cjs += tidxz * c_nbnxnGpuClusterpairSplit * c_nbnxnGpuJgroupSize;
+        sm_nextSlotPtr += (NTHREAD_Z * c_nbnxnGpuClusterpairSplit * c_nbnxnGpuJgroupSize * sizeof(*cjs));
+    }
 
 #    ifndef LJ_COMB
     /* shmem buffer for i atom-type pre-loading */
@@ -384,12 +393,15 @@ __launch_bounds__(THREADS_PER_BLOCK)
         if (imask)
 #    endif
         {
-            /* Pre-load cj into shared memory on both warps separately */
-            if ((tidxj == 0 | tidxj == 4) & (tidxi < c_nbnxnGpuJgroupSize))
+            if (c_preloadCj)
             {
-                cjs[tidxi + tidxj * c_nbnxnGpuJgroupSize / c_splitClSize] = pl_cj4[j4].cj[tidxi];
+                /* Pre-load cj into shared memory on both warps separately */
+                if ((tidxj == 0 | tidxj == 4) & (tidxi < c_nbnxnGpuJgroupSize))
+                {
+                    cjs[tidxi + tidxj * c_nbnxnGpuJgroupSize / c_splitClSize] = pl_cj4[j4].cj[tidxi];
+                }
+                __syncwarp(c_fullWarpMask);
             }
-            __syncwarp(c_fullWarpMask);
 
             /* Unrolling this loop
                - with pruning leads to register spilling;
@@ -401,7 +413,9 @@ __launch_bounds__(THREADS_PER_BLOCK)
                 {
                     mask_ji = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
 
-                    cj = cjs[jm + (tidxj & 4) * c_nbnxnGpuJgroupSize / c_splitClSize];
+                    cj = c_preloadCj ? cjs[jm + (tidxj & 4) * c_nbnxnGpuJgroupSize / c_splitClSize]
+                                     : cj = pl_cj4[j4].cj[jm];
+
                     aj = cj * c_clSize + tidxj;
 
                     /* load j atom data */
@@ -625,8 +639,11 @@ __launch_bounds__(THREADS_PER_BLOCK)
             pl_cj4[j4].imei[widx].imask = imask;
 #    endif
         }
-        // avoid shared memory WAR hazards between loop iterations
-        __syncwarp(c_fullWarpMask);
+        if (c_preloadCj)
+        {
+            // avoid shared memory WAR hazards on sm_cjs between loop iterations
+            __syncwarp(c_fullWarpMask);
+        }
     }
 
     /* skip central shifts when summing shift forces */
index 8219ad16a351e60e931fb14e1facb902479279df..fc7df11f5893ef825cdbcf740fe82b7d8c51a8e6 100644 (file)
@@ -141,6 +141,11 @@ nbnxn_kernel_prune_cuda<false>(const NBAtomDataGpu, const NBParamGpu, const Nbnx
     unsigned int bidx  = blockIdx.x;
     unsigned int widx  = (threadIdx.y * c_clSize) / warp_size; /* warp index */
 
+    // cj preload is off in the following cases:
+    // - sm_70 (V100), sm_8x (A100, GA100), sm_75 (TU102)
+    // - for future arch (> 8.6 at the time of writing) we assume it is better to keep it off
+    constexpr bool c_preloadCj = (GMX_PTX_ARCH < 700);
+
     /*********************************************************************
      * Set up shared memory pointers.
      * sm_nextSlotPtr should always be updated to point to the "next slot",
@@ -157,9 +162,12 @@ nbnxn_kernel_prune_cuda<false>(const NBAtomDataGpu, const NBParamGpu, const Nbnx
 
     /* shmem buffer for cj, for each warp separately */
     int* cjs = reinterpret_cast<int*>(sm_nextSlotPtr);
-    /* the cjs buffer's use expects a base pointer offset for pairs of warps in the j-concurrent execution */
-    cjs += tidxz * c_nbnxnGpuClusterpairSplit * c_nbnxnGpuJgroupSize;
-    sm_nextSlotPtr += (NTHREAD_Z * c_nbnxnGpuClusterpairSplit * c_nbnxnGpuJgroupSize * sizeof(*cjs));
+    if (c_preloadCj)
+    {
+        /* the cjs buffer's use expects a base pointer offset for pairs of warps in the j-concurrent execution */
+        cjs += tidxz * c_nbnxnGpuClusterpairSplit * c_nbnxnGpuJgroupSize;
+        sm_nextSlotPtr += (NTHREAD_Z * c_nbnxnGpuClusterpairSplit * c_nbnxnGpuJgroupSize * sizeof(*cjs));
+    }
     /*********************************************************************/
 
 
@@ -211,12 +219,15 @@ nbnxn_kernel_prune_cuda<false>(const NBAtomDataGpu, const NBParamGpu, const Nbnx
 
         if (imaskCheck)
         {
-            /* Pre-load cj into shared memory on both warps separately */
-            if ((tidxj == 0 || tidxj == 4) && tidxi < c_nbnxnGpuJgroupSize)
+            if (c_preloadCj)
             {
-                cjs[tidxi + tidxj * c_nbnxnGpuJgroupSize / c_splitClSize] = pl_cj4[j4].cj[tidxi];
+                /* Pre-load cj into shared memory on both warps separately */
+                if ((tidxj == 0 || tidxj == 4) && tidxi < c_nbnxnGpuJgroupSize)
+                {
+                    cjs[tidxi + tidxj * c_nbnxnGpuJgroupSize / c_splitClSize] = pl_cj4[j4].cj[tidxi];
+                }
+                __syncwarp(c_fullWarpMask);
             }
-            __syncwarp(c_fullWarpMask);
 
 #    pragma unroll 4
             for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
@@ -224,8 +235,8 @@ nbnxn_kernel_prune_cuda<false>(const NBAtomDataGpu, const NBParamGpu, const Nbnx
                 if (imaskCheck & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster)))
                 {
                     unsigned int mask_ji = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
-
-                    int cj = cjs[jm + (tidxj & 4) * c_nbnxnGpuJgroupSize / c_splitClSize];
+                    int cj = c_preloadCj ? cjs[jm + (tidxj & 4) * c_nbnxnGpuJgroupSize / c_splitClSize]
+                                         : pl_cj4[j4].cj[jm];
                     int aj = cj * c_clSize + tidxj;
 
                     /* load j atom data */
@@ -274,8 +285,11 @@ nbnxn_kernel_prune_cuda<false>(const NBAtomDataGpu, const NBParamGpu, const Nbnx
             /* update the imask with only the pairs up to rlistInner */
             plist.cj4[j4].imei[widx].imask = imaskNew;
         }
-        // avoid shared memory WAR hazards between loop iterations
-        __syncwarp(c_fullWarpMask);
+        if (c_preloadCj)
+        {
+            // avoid shared memory WAR hazards on sm_cjs between loop iterations
+            __syncwarp(c_fullWarpMask);
+        }
     }
 }
 #endif /* FUNCTION_DECLARATION_ONLY */
index deeb17a8faa6b01c8fd113bf4389125b42d7ebc3..cf0148ab89500ee21483e29c2dbc499952d23e97 100644 (file)
@@ -75,7 +75,6 @@ static const unsigned __device__ superClInteractionMask =
 static const float __device__ c_oneSixth    = 0.16666667F;
 static const float __device__ c_oneTwelveth = 0.08333333F;
 
-
 /*! Convert LJ sigma,epsilon parameters to C6,C12. */
 static __forceinline__ __device__ void
 convert_sigma_epsilon_to_c6_c12(const float sigma, const float epsilon, float* c6, float* c12)