Redesigned SIMD module and unit tests.
[alexxy/gromacs.git] / src / gromacs / mdlib / nbnxn_search.c
index 457db913939dfb617239c95e2f79861b8ad472ca..2f93f3036cfcb8b151579576f955cff11f58aed0 100644 (file)
 #include "gromacs/fileio/gmxfio.h"
 
 #ifdef NBNXN_SEARCH_BB_SIMD4
-/* We use 4-wide SIMD for bounding box calculations */
+/* Always use 4-wide SIMD for bounding box calculations */
 
-#ifndef GMX_DOUBLE
+#    ifndef GMX_DOUBLE
 /* Single precision BBs + coordinates, we can also load coordinates with SIMD */
-#define NBNXN_SEARCH_SIMD4_FLOAT_X_BB
-#endif
+#        define NBNXN_SEARCH_SIMD4_FLOAT_X_BB
+#    endif
 
-#if defined NBNXN_SEARCH_SIMD4_FLOAT_X_BB && (GPU_NSUBCELL == 4 || GPU_NSUBCELL == 8)
+#    if defined NBNXN_SEARCH_SIMD4_FLOAT_X_BB && (GPU_NSUBCELL == 4 || GPU_NSUBCELL == 8)
 /* Store bounding boxes with x, y and z coordinates in packs of 4 */
-#define NBNXN_PBB_SIMD4
-#endif
+#        define NBNXN_PBB_SIMD4
+#    endif
 
 /* The packed bounding box coordinate stride is always set to 4.
  * With AVX we could use 8, but that turns out not to be faster.
  */
-#define STRIDE_PBB        4
-#define STRIDE_PBB_2LOG   2
+#    define STRIDE_PBB        4
+#    define STRIDE_PBB_2LOG   2
 
 #endif /* NBNXN_SEARCH_BB_SIMD4 */
 
 #define X_IND_CI_SIMD_2XNN(ci) X_IND_CI_J8(ci)
 #define X_IND_CJ_SIMD_2XNN(cj) X_IND_CJ_J8(cj)
 #else
-#error "unsupported GMX_NBNXN_SIMD_WIDTH"
+#error "unsupported GMX_SIMD_REAL_WIDTH"
 #endif
 #endif
 #endif
@@ -808,20 +808,20 @@ static void calc_bounding_box_x_x4_halves(int na, const real *x,
          * so we don't need to treat special cases in the rest of the code.
          */
 #ifdef NBNXN_SEARCH_BB_SIMD4
-        gmx_simd4_store_r(&bbj[1].lower[0], gmx_simd4_load_bb_pr(&bbj[0].lower[0]));
-        gmx_simd4_store_r(&bbj[1].upper[0], gmx_simd4_load_bb_pr(&bbj[0].upper[0]));
+        gmx_simd4_store_f(&bbj[1].lower[0], gmx_simd4_load_f(&bbj[0].lower[0]));
+        gmx_simd4_store_f(&bbj[1].upper[0], gmx_simd4_load_f(&bbj[0].upper[0]));
 #else
         bbj[1] = bbj[0];
 #endif
     }
 
 #ifdef NBNXN_SEARCH_BB_SIMD4
-    gmx_simd4_store_r(&bb->lower[0],
-                      gmx_simd4_min_r(gmx_simd4_load_bb_pr(&bbj[0].lower[0]),
-                                      gmx_simd4_load_bb_pr(&bbj[1].lower[0])));
-    gmx_simd4_store_r(&bb->upper[0],
-                      gmx_simd4_max_r(gmx_simd4_load_bb_pr(&bbj[0].upper[0]),
-                                      gmx_simd4_load_bb_pr(&bbj[1].upper[0])));
+    gmx_simd4_store_f(&bb->lower[0],
+                      gmx_simd4_min_f(gmx_simd4_load_f(&bbj[0].lower[0]),
+                                      gmx_simd4_load_f(&bbj[1].lower[0])));
+    gmx_simd4_store_f(&bb->upper[0],
+                      gmx_simd4_max_f(gmx_simd4_load_f(&bbj[0].upper[0]),
+                                      gmx_simd4_load_f(&bbj[1].upper[0])));
 #else
     {
         int i;
@@ -877,23 +877,23 @@ static void calc_bounding_box_xxxx(int na, int stride, const real *x, float *bb)
 /* Coordinate order xyz?, bb order xyz0 */
 static void calc_bounding_box_simd4(int na, const float *x, nbnxn_bb_t *bb)
 {
-    gmx_simd4_real_t bb_0_S, bb_1_S;
-    gmx_simd4_real_t x_S;
+    gmx_simd4_float_t bb_0_S, bb_1_S;
+    gmx_simd4_float_t x_S;
 
-    int              i;
+    int               i;
 
-    bb_0_S = gmx_simd4_load_bb_pr(x);
+    bb_0_S = gmx_simd4_load_f(x);
     bb_1_S = bb_0_S;
 
     for (i = 1; i < na; i++)
     {
-        x_S    = gmx_simd4_load_bb_pr(x+i*NNBSBB_C);
-        bb_0_S = gmx_simd4_min_r(bb_0_S, x_S);
-        bb_1_S = gmx_simd4_max_r(bb_1_S, x_S);
+        x_S    = gmx_simd4_load_f(x+i*NNBSBB_C);
+        bb_0_S = gmx_simd4_min_f(bb_0_S, x_S);
+        bb_1_S = gmx_simd4_max_f(bb_1_S, x_S);
     }
 
-    gmx_simd4_store_r(&bb->lower[0], bb_0_S);
-    gmx_simd4_store_r(&bb->upper[0], bb_1_S);
+    gmx_simd4_store_f(&bb->lower[0], bb_0_S);
+    gmx_simd4_store_f(&bb->upper[0], bb_1_S);
 }
 
 /* Coordinate order xyz?, bb order xxxxyyyyzzzz */
@@ -928,14 +928,14 @@ static void combine_bounding_box_pairs(nbnxn_grid_t *grid, const nbnxn_bb_t *bb)
         for (c2 = sc2; c2 < sc2+nc2; c2++)
         {
 #ifdef NBNXN_SEARCH_BB_SIMD4
-            gmx_simd4_real_t min_S, max_S;
-
-            min_S = gmx_simd4_min_r(gmx_simd4_load_bb_pr(&bb[c2*2+0].lower[0]),
-                                    gmx_simd4_load_bb_pr(&bb[c2*2+1].lower[0]));
-            max_S = gmx_simd4_max_r(gmx_simd4_load_bb_pr(&bb[c2*2+0].upper[0]),
-                                    gmx_simd4_load_bb_pr(&bb[c2*2+1].upper[0]));
-            gmx_simd4_store_r(&grid->bbj[c2].lower[0], min_S);
-            gmx_simd4_store_r(&grid->bbj[c2].upper[0], max_S);
+            gmx_simd4_float_t min_S, max_S;
+
+            min_S = gmx_simd4_min_f(gmx_simd4_load_f(&bb[c2*2+0].lower[0]),
+                                    gmx_simd4_load_f(&bb[c2*2+1].lower[0]));
+            max_S = gmx_simd4_max_f(gmx_simd4_load_f(&bb[c2*2+0].upper[0]),
+                                    gmx_simd4_load_f(&bb[c2*2+1].upper[0]));
+            gmx_simd4_store_f(&grid->bbj[c2].lower[0], min_S);
+            gmx_simd4_store_f(&grid->bbj[c2].upper[0], max_S);
 #else
             for (j = 0; j < NNBSBB_C; j++)
             {
@@ -2075,74 +2075,74 @@ static float subc_bb_dist2(int si, const nbnxn_bb_t *bb_i_ci,
 static float subc_bb_dist2_simd4(int si, const nbnxn_bb_t *bb_i_ci,
                                  int csj, const nbnxn_bb_t *bb_j_all)
 {
-    gmx_simd4_real_t bb_i_S0, bb_i_S1;
-    gmx_simd4_real_t bb_j_S0, bb_j_S1;
-    gmx_simd4_real_t dl_S;
-    gmx_simd4_real_t dh_S;
-    gmx_simd4_real_t dm_S;
-    gmx_simd4_real_t dm0_S;
+    gmx_simd4_float_t bb_i_S0, bb_i_S1;
+    gmx_simd4_float_t bb_j_S0, bb_j_S1;
+    gmx_simd4_float_t dl_S;
+    gmx_simd4_float_t dh_S;
+    gmx_simd4_float_t dm_S;
+    gmx_simd4_float_t dm0_S;
 
-    bb_i_S0 = gmx_simd4_load_bb_pr(&bb_i_ci[si].lower[0]);
-    bb_i_S1 = gmx_simd4_load_bb_pr(&bb_i_ci[si].upper[0]);
-    bb_j_S0 = gmx_simd4_load_bb_pr(&bb_j_all[csj].lower[0]);
-    bb_j_S1 = gmx_simd4_load_bb_pr(&bb_j_all[csj].upper[0]);
+    bb_i_S0 = gmx_simd4_load_f(&bb_i_ci[si].lower[0]);
+    bb_i_S1 = gmx_simd4_load_f(&bb_i_ci[si].upper[0]);
+    bb_j_S0 = gmx_simd4_load_f(&bb_j_all[csj].lower[0]);
+    bb_j_S1 = gmx_simd4_load_f(&bb_j_all[csj].upper[0]);
 
-    dl_S    = gmx_simd4_sub_r(bb_i_S0, bb_j_S1);
-    dh_S    = gmx_simd4_sub_r(bb_j_S0, bb_i_S1);
+    dl_S    = gmx_simd4_sub_f(bb_i_S0, bb_j_S1);
+    dh_S    = gmx_simd4_sub_f(bb_j_S0, bb_i_S1);
 
-    dm_S    = gmx_simd4_max_r(dl_S, dh_S);
-    dm0_S   = gmx_simd4_max_r(dm_S, gmx_simd4_setzero_r());
+    dm_S    = gmx_simd4_max_f(dl_S, dh_S);
+    dm0_S   = gmx_simd4_max_f(dm_S, gmx_simd4_setzero_f());
 
-    return gmx_simd4_dotproduct3_r(dm0_S, dm0_S);
+    return gmx_simd4_dotproduct3_f(dm0_S, dm0_S);
 }
 
 /* Calculate bb bounding distances of bb_i[si,...,si+3] and store them in d2 */
 #define SUBC_BB_DIST2_SIMD4_XXXX_INNER(si, bb_i, d2) \
     {                                                \
-        int              shi;                                  \
+        int               shi;                                  \
                                                  \
-        gmx_simd4_real_t dx_0, dy_0, dz_0;                       \
-        gmx_simd4_real_t dx_1, dy_1, dz_1;                       \
+        gmx_simd4_float_t dx_0, dy_0, dz_0;                    \
+        gmx_simd4_float_t dx_1, dy_1, dz_1;                    \
                                                  \
-        gmx_simd4_real_t mx, my, mz;                             \
-        gmx_simd4_real_t m0x, m0y, m0z;                          \
+        gmx_simd4_float_t mx, my, mz;                          \
+        gmx_simd4_float_t m0x, m0y, m0z;                       \
                                                  \
-        gmx_simd4_real_t d2x, d2y, d2z;                          \
-        gmx_simd4_real_t d2s, d2t;                              \
+        gmx_simd4_float_t d2x, d2y, d2z;                       \
+        gmx_simd4_float_t d2s, d2t;                            \
                                                  \
         shi = si*NNBSBB_D*DIM;                       \
                                                  \
-        xi_l = gmx_simd4_load_bb_pr(bb_i+shi+0*STRIDE_PBB);   \
-        yi_l = gmx_simd4_load_bb_pr(bb_i+shi+1*STRIDE_PBB);   \
-        zi_l = gmx_simd4_load_bb_pr(bb_i+shi+2*STRIDE_PBB);   \
-        xi_h = gmx_simd4_load_bb_pr(bb_i+shi+3*STRIDE_PBB);   \
-        yi_h = gmx_simd4_load_bb_pr(bb_i+shi+4*STRIDE_PBB);   \
-        zi_h = gmx_simd4_load_bb_pr(bb_i+shi+5*STRIDE_PBB);   \
+        xi_l = gmx_simd4_load_f(bb_i+shi+0*STRIDE_PBB);   \
+        yi_l = gmx_simd4_load_f(bb_i+shi+1*STRIDE_PBB);   \
+        zi_l = gmx_simd4_load_f(bb_i+shi+2*STRIDE_PBB);   \
+        xi_h = gmx_simd4_load_f(bb_i+shi+3*STRIDE_PBB);   \
+        yi_h = gmx_simd4_load_f(bb_i+shi+4*STRIDE_PBB);   \
+        zi_h = gmx_simd4_load_f(bb_i+shi+5*STRIDE_PBB);   \
                                                  \
-        dx_0 = gmx_simd4_sub_r(xi_l, xj_h);                \
-        dy_0 = gmx_simd4_sub_r(yi_l, yj_h);                \
-        dz_0 = gmx_simd4_sub_r(zi_l, zj_h);                \
+        dx_0 = gmx_simd4_sub_f(xi_l, xj_h);                 \
+        dy_0 = gmx_simd4_sub_f(yi_l, yj_h);                 \
+        dz_0 = gmx_simd4_sub_f(zi_l, zj_h);                 \
                                                  \
-        dx_1 = gmx_simd4_sub_r(xj_l, xi_h);                \
-        dy_1 = gmx_simd4_sub_r(yj_l, yi_h);                \
-        dz_1 = gmx_simd4_sub_r(zj_l, zi_h);                \
+        dx_1 = gmx_simd4_sub_f(xj_l, xi_h);                 \
+        dy_1 = gmx_simd4_sub_f(yj_l, yi_h);                 \
+        dz_1 = gmx_simd4_sub_f(zj_l, zi_h);                 \
                                                  \
-        mx   = gmx_simd4_max_r(dx_0, dx_1);                \
-        my   = gmx_simd4_max_r(dy_0, dy_1);                \
-        mz   = gmx_simd4_max_r(dz_0, dz_1);                \
+        mx   = gmx_simd4_max_f(dx_0, dx_1);                 \
+        my   = gmx_simd4_max_f(dy_0, dy_1);                 \
+        mz   = gmx_simd4_max_f(dz_0, dz_1);                 \
                                                  \
-        m0x  = gmx_simd4_max_r(mx, zero);                  \
-        m0y  = gmx_simd4_max_r(my, zero);                  \
-        m0z  = gmx_simd4_max_r(mz, zero);                  \
+        m0x  = gmx_simd4_max_f(mx, zero);                   \
+        m0y  = gmx_simd4_max_f(my, zero);                   \
+        m0z  = gmx_simd4_max_f(mz, zero);                   \
                                                  \
-        d2x  = gmx_simd4_mul_r(m0x, m0x);                  \
-        d2y  = gmx_simd4_mul_r(m0y, m0y);                  \
-        d2z  = gmx_simd4_mul_r(m0z, m0z);                  \
+        d2x  = gmx_simd4_mul_f(m0x, m0x);                   \
+        d2y  = gmx_simd4_mul_f(m0y, m0y);                   \
+        d2z  = gmx_simd4_mul_f(m0z, m0z);                   \
                                                  \
-        d2s  = gmx_simd4_add_r(d2x, d2y);                  \
-        d2t  = gmx_simd4_add_r(d2s, d2z);                  \
+        d2s  = gmx_simd4_add_f(d2x, d2y);                   \
+        d2t  = gmx_simd4_add_f(d2s, d2z);                   \
                                                  \
-        gmx_simd4_store_r(d2+si, d2t);                     \
+        gmx_simd4_store_f(d2+si, d2t);                      \
     }
 
 /* 4-wide SIMD code for nsi bb distances for bb format xxxxyyyyzzzz */
@@ -2150,21 +2150,21 @@ static void subc_bb_dist2_simd4_xxxx(const float *bb_j,
                                      int nsi, const float *bb_i,
                                      float *d2)
 {
-    gmx_simd4_real_t xj_l, yj_l, zj_l;
-    gmx_simd4_real_t xj_h, yj_h, zj_h;
-    gmx_simd4_real_t xi_l, yi_l, zi_l;
-    gmx_simd4_real_t xi_h, yi_h, zi_h;
+    gmx_simd4_float_t xj_l, yj_l, zj_l;
+    gmx_simd4_float_t xj_h, yj_h, zj_h;
+    gmx_simd4_float_t xi_l, yi_l, zi_l;
+    gmx_simd4_float_t xi_h, yi_h, zi_h;
 
-    gmx_simd4_real_t zero;
+    gmx_simd4_float_t zero;
 
-    zero = gmx_simd4_setzero_r();
+    zero = gmx_simd4_setzero_f();
 
-    xj_l = gmx_simd4_set1_r(bb_j[0*STRIDE_PBB]);
-    yj_l = gmx_simd4_set1_r(bb_j[1*STRIDE_PBB]);
-    zj_l = gmx_simd4_set1_r(bb_j[2*STRIDE_PBB]);
-    xj_h = gmx_simd4_set1_r(bb_j[3*STRIDE_PBB]);
-    yj_h = gmx_simd4_set1_r(bb_j[4*STRIDE_PBB]);
-    zj_h = gmx_simd4_set1_r(bb_j[5*STRIDE_PBB]);
+    xj_l = gmx_simd4_set1_f(bb_j[0*STRIDE_PBB]);
+    yj_l = gmx_simd4_set1_f(bb_j[1*STRIDE_PBB]);
+    zj_l = gmx_simd4_set1_f(bb_j[2*STRIDE_PBB]);
+    xj_h = gmx_simd4_set1_f(bb_j[3*STRIDE_PBB]);
+    yj_h = gmx_simd4_set1_f(bb_j[4*STRIDE_PBB]);
+    zj_h = gmx_simd4_set1_f(bb_j[5*STRIDE_PBB]);
 
     /* Here we "loop" over si (0,STRIDE_PBB) from 0 to nsi with step STRIDE_PBB.
      * But as we know the number of iterations is 1 or 2, we unroll manually.
@@ -2211,14 +2211,6 @@ static gmx_bool subc_in_range_x(int na_c,
 }
 
 #ifdef NBNXN_SEARCH_SIMD4_FLOAT_X_BB
-/* When we make seperate single/double precision SIMD vector operation
- * include files, this function should be moved there (also using FMA).
- */
-static inline gmx_simd4_real_t
-gmx_simd4_calc_rsq_r(gmx_simd4_real_t x, gmx_simd4_real_t y, gmx_simd4_real_t z)
-{
-    return gmx_simd4_add_r( gmx_simd4_add_r( gmx_simd4_mul_r(x, x), gmx_simd4_mul_r(y, y) ), gmx_simd4_mul_r(z, z) );
-}
 
 /* 4-wide SIMD function which determines if any atom pair between two cells,
  * both with 8 atoms, is within distance sqrt(rl2).
@@ -2240,12 +2232,12 @@ static gmx_bool subc_in_range_simd4(int na_c,
     rc2_S   = gmx_simd4_set1_r(rl2);
 
     dim_stride = NBNXN_GPU_CLUSTER_SIZE/STRIDE_PBB*DIM;
-    ix_S0      = gmx_simd4_load_bb_pr(x_i+(si*dim_stride+0)*STRIDE_PBB);
-    iy_S0      = gmx_simd4_load_bb_pr(x_i+(si*dim_stride+1)*STRIDE_PBB);
-    iz_S0      = gmx_simd4_load_bb_pr(x_i+(si*dim_stride+2)*STRIDE_PBB);
-    ix_S1      = gmx_simd4_load_bb_pr(x_i+(si*dim_stride+3)*STRIDE_PBB);
-    iy_S1      = gmx_simd4_load_bb_pr(x_i+(si*dim_stride+4)*STRIDE_PBB);
-    iz_S1      = gmx_simd4_load_bb_pr(x_i+(si*dim_stride+5)*STRIDE_PBB);
+    ix_S0      = gmx_simd4_load_r(x_i+(si*dim_stride+0)*STRIDE_PBB);
+    iy_S0      = gmx_simd4_load_r(x_i+(si*dim_stride+1)*STRIDE_PBB);
+    iz_S0      = gmx_simd4_load_r(x_i+(si*dim_stride+2)*STRIDE_PBB);
+    ix_S1      = gmx_simd4_load_r(x_i+(si*dim_stride+3)*STRIDE_PBB);
+    iy_S1      = gmx_simd4_load_r(x_i+(si*dim_stride+4)*STRIDE_PBB);
+    iz_S1      = gmx_simd4_load_r(x_i+(si*dim_stride+5)*STRIDE_PBB);
 
     /* We loop from the outer to the inner particles to maximize
      * the chance that we find a pair in range quickly and return.