implemented plain-C SIMD macros for reference
[alexxy/gromacs.git] / src / mdlib / nbnxn_search.c
index 16cdd1fea60fd439abba01f54ae34c661f902355..8c5a93d8cb97e58e847a1faa0689be088e8d3041 100644 (file)
 #include "vec.h"
 #include "pbc.h"
 #include "nbnxn_consts.h"
+/* nbnxn_internal.h included gmx_simd_macros.h */
 #include "nbnxn_internal.h"
+#ifdef GMX_NBNXN_SIMD
+#include "gmx_simd_vec.h"
+#endif
 #include "nbnxn_atomdata.h"
 #include "nbnxn_search.h"
 #include "gmx_cyclecounter.h"
 #define X_IND_CJ_J8(cj)  ((cj)*STRIDE_P8)
 
 /* The j-cluster size is matched to the SIMD width */
-#if GMX_NBNXN_SIMD_BITWIDTH == 128
-#ifdef GMX_DOUBLE
+#if GMX_SIMD_WIDTH_HERE == 2
 #define CI_TO_CJ_SIMD_4XN(ci)  CI_TO_CJ_J2(ci)
 #define X_IND_CI_SIMD_4XN(ci)  X_IND_CI_J2(ci)
 #define X_IND_CJ_SIMD_4XN(cj)  X_IND_CJ_J2(cj)
 #else
-#define CI_TO_CJ_SIMD_4XN(ci)  CI_TO_CJ_J4(ci)
-#define X_IND_CI_SIMD_4XN(ci)  X_IND_CI_J4(ci)
-#define X_IND_CJ_SIMD_4XN(cj)  X_IND_CJ_J4(cj)
-#endif
-#else
-#if GMX_NBNXN_SIMD_BITWIDTH == 256
-#ifdef GMX_DOUBLE
+#if GMX_SIMD_WIDTH_HERE == 4
 #define CI_TO_CJ_SIMD_4XN(ci)  CI_TO_CJ_J4(ci)
 #define X_IND_CI_SIMD_4XN(ci)  X_IND_CI_J4(ci)
 #define X_IND_CJ_SIMD_4XN(cj)  X_IND_CJ_J4(cj)
 #else
+#if GMX_SIMD_WIDTH_HERE == 8
 #define CI_TO_CJ_SIMD_4XN(ci)  CI_TO_CJ_J8(ci)
 #define X_IND_CI_SIMD_4XN(ci)  X_IND_CI_J8(ci)
 #define X_IND_CJ_SIMD_4XN(cj)  X_IND_CJ_J8(cj)
 #define CI_TO_CJ_SIMD_2XNN(ci) CI_TO_CJ_J4(ci)
 #define X_IND_CI_SIMD_2XNN(ci) X_IND_CI_J4(ci)
 #define X_IND_CJ_SIMD_2XNN(cj) X_IND_CJ_J4(cj)
-#endif
+#else
+#if GMX_SIMD_WIDTH_HERE == 16
+#define CI_TO_CJ_SIMD_2XNN(ci) CI_TO_CJ_J8(ci)
+#define X_IND_CI_SIMD_2XNN(ci) X_IND_CI_J8(ci)
+#define X_IND_CJ_SIMD_2XNN(cj) X_IND_CJ_J8(cj)
 #else
 #error "unsupported GMX_NBNXN_SIMD_WIDTH"
 #endif
 #endif
+#endif
+#endif
 
 #endif /* GMX_NBNXN_SIMD */
 
 
-/* Interaction masks for 4xN atom interactions.
- * Bit i*CJ_SIZE + j tells if atom i and j interact.
- */
-/* All interaction mask is the same for all kernels */
-#define NBNXN_INT_MASK_ALL        0xffffffff
-/* 4x4 kernel diagonal mask */
-#define NBNXN_INT_MASK_DIAG       0x08ce
-/* 4x2 kernel diagonal masks */
-#define NBNXN_INT_MASK_DIAG_J2_0  0x0002
-#define NBNXN_INT_MASK_DIAG_J2_1  0x002F
-/* 4x8 kernel diagonal masks */
-#define NBNXN_INT_MASK_DIAG_J8_0  0xf0f8fcfe
-#define NBNXN_INT_MASK_DIAG_J8_1  0x0080c0e0
-
-
 #ifdef NBNXN_SEARCH_BB_SSE
 /* Store bounding boxes corners as quadruplets: xxxxyyyyzzzz */
 #define NBNXN_BBXXXX
@@ -293,7 +282,7 @@ int nbnxn_kernel_to_cj_size(int nb_kernel_type)
     int cj_size          = 0;
 
 #ifdef GMX_NBNXN_SIMD
-    nbnxn_simd_width = GMX_NBNXN_SIMD_BITWIDTH/(sizeof(real)*8);
+    nbnxn_simd_width = GMX_SIMD_WIDTH_HERE;
 #endif
 
     switch (nb_kernel_type)
@@ -810,12 +799,14 @@ static void calc_bounding_box_x_x8(int na, const real *x, float *bb)
     bb[BBU_Z] = R2F_U(zh);
 }
 
-#ifdef NBNXN_SEARCH_BB_SSE
-
 /* Packed coordinates, bb order xyz0 */
 static void calc_bounding_box_x_x4_halves(int na, const real *x,
                                           float *bb, float *bbj)
 {
+#ifndef NBNXN_SEARCH_BB_SSE
+    int i;
+#endif
+
     calc_bounding_box_x_x4(min(na, 2), x, bbj);
 
     if (na > 2)
@@ -827,16 +818,33 @@ static void calc_bounding_box_x_x4_halves(int na, const real *x,
         /* Set the "empty" bounding box to the same as the first one,
          * so we don't need to treat special cases in the rest of the code.
          */
+#ifdef NBNXN_SEARCH_BB_SSE
         _mm_store_ps(bbj+NNBSBB_B, _mm_load_ps(bbj));
         _mm_store_ps(bbj+NNBSBB_B+NNBSBB_C, _mm_load_ps(bbj+NNBSBB_C));
+#else
+        for (i = 0; i < NNBSBB_B; i++)
+        {
+            bbj[NNBSBB_B + i] = bbj[i];
+        }
+#endif
     }
 
+#ifdef NBNXN_SEARCH_BB_SSE
     _mm_store_ps(bb, _mm_min_ps(_mm_load_ps(bbj),
                                 _mm_load_ps(bbj+NNBSBB_B)));
     _mm_store_ps(bb+NNBSBB_C, _mm_max_ps(_mm_load_ps(bbj+NNBSBB_C),
                                          _mm_load_ps(bbj+NNBSBB_B+NNBSBB_C)));
+#else
+    for (i = 0; i < NNBSBB_C; i++)
+    {
+        bb[           i] = min(bbj[           i], bbj[NNBSBB_B +            i]);
+        bb[NNBSBB_C + i] = max(bbj[NNBSBB_C + i], bbj[NNBSBB_B + NNBSBB_C + i]);
+    }
+#endif
 }
 
+#ifdef NBNXN_SEARCH_BB_SSE
+
 /* Coordinate order xyz, bb order xxxxyyyyzzzz */
 static void calc_bounding_box_xxxx(int na, int stride, const real *x, float *bb)
 {
@@ -913,13 +921,11 @@ static void calc_bounding_box_xxxx_sse(int na, const float *x,
 
 #endif /* NBNXN_SEARCH_SSE_SINGLE */
 
-#ifdef NBNXN_SEARCH_BB_SSE
 
 /* Combines pairs of consecutive bounding boxes */
 static void combine_bounding_box_pairs(nbnxn_grid_t *grid, const float *bb)
 {
     int    i, j, sc2, nc2, c2;
-    __m128 min_SSE, max_SSE;
 
     for (i = 0; i < grid->ncx*grid->ncy; i++)
     {
@@ -929,12 +935,24 @@ static void combine_bounding_box_pairs(nbnxn_grid_t *grid, const float *bb)
         nc2 = (grid->cxy_na[i]+3)>>(2+1);
         for (c2 = sc2; c2 < sc2+nc2; c2++)
         {
+#ifdef NBNXN_SEARCH_BB_SSE
+            __m128 min_SSE, max_SSE;
+
             min_SSE = _mm_min_ps(_mm_load_ps(bb+(c2*4+0)*NNBSBB_C),
                                  _mm_load_ps(bb+(c2*4+2)*NNBSBB_C));
             max_SSE = _mm_max_ps(_mm_load_ps(bb+(c2*4+1)*NNBSBB_C),
                                  _mm_load_ps(bb+(c2*4+3)*NNBSBB_C));
             _mm_store_ps(grid->bbj+(c2*2+0)*NNBSBB_C, min_SSE);
             _mm_store_ps(grid->bbj+(c2*2+1)*NNBSBB_C, max_SSE);
+#else
+            for (j = 0; j < NNBSBB_C; j++)
+            {
+                grid->bbj[(c2*2+0)*NNBSBB_C+j] = min(bb[(c2*4+0)*NNBSBB_C+j],
+                                                     bb[(c2*4+2)*NNBSBB_C+j]);
+                grid->bbj[(c2*2+1)*NNBSBB_C+j] = max(bb[(c2*4+1)*NNBSBB_C+j],
+                                                     bb[(c2*4+3)*NNBSBB_C+j]);
+            }
+#endif
         }
         if (((grid->cxy_na[i]+3)>>2) & 1)
         {
@@ -948,8 +966,6 @@ static void combine_bounding_box_pairs(nbnxn_grid_t *grid, const float *bb)
     }
 }
 
-#endif
-
 
 /* Prints the average bb size, used for debug output */
 static void print_bbsizes_simple(FILE                *fp,
@@ -1147,7 +1163,7 @@ void fill_cell(const nbnxn_search_t nbs,
         offset = ((a0 - grid->cell0*grid->na_sc)>>grid->na_c_2log)*NNBSBB_B;
         bb_ptr = grid->bb + offset;
 
-#if defined GMX_DOUBLE && defined NBNXN_SEARCH_BB_SSE
+#if defined GMX_NBNXN_SIMD && GMX_SIMD_WIDTH_HERE == 2
         if (2*grid->na_cj == grid->na_c)
         {
             calc_bounding_box_x_x4_halves(na, nbat->x+X4_IND_A(a0), bb_ptr,
@@ -1652,12 +1668,10 @@ static void calc_cell_indices(const nbnxn_search_t nbs,
         }
     }
 
-#ifdef NBNXN_SEARCH_BB_SSE
     if (grid->bSimple && nbat->XFormat == nbatX8)
     {
         combine_bounding_box_pairs(grid, grid->bb);
     }
-#endif
 
     if (!grid->bSimple)
     {
@@ -1926,12 +1940,10 @@ void nbnxn_grid_add_simple(nbnxn_search_t    nbs,
         }
     }
 
-#ifdef NBNXN_SEARCH_BB_SSE
     if (grid->bSimple && nbat->XFormat == nbatX8)
     {
         combine_bounding_box_pairs(grid, grid->bb_simple);
     }
-#endif
 }
 
 void nbnxn_get_ncells(nbnxn_search_t nbs, int *ncx, int *ncy)
@@ -2067,8 +2079,7 @@ static float subc_bb_dist2(int si, const float *bb_i_ci,
 #ifdef NBNXN_SEARCH_BB_SSE
 
 /* SSE code for bb distance for bb format xyz0 */
-static float subc_bb_dist2_sse(int na_c,
-                               int si, const float *bb_i_ci,
+static float subc_bb_dist2_sse(int si, const float *bb_i_ci,
                                int csj, const float *bb_j_all)
 {
     const float *bb_i, *bb_j;
@@ -2231,8 +2242,20 @@ static gmx_bool subc_in_range_x(int na_c,
     return FALSE;
 }
 
+#ifdef NBNXN_SEARCH_SSE_SINGLE
+/* When we make seperate single/double precision SIMD vector operation
+ * include files, this function should be moved there (also using FMA).
+ */
+static inline __m128
+gmx_mm_calc_rsq_ps(__m128 x, __m128 y, __m128 z)
+{
+    return _mm_add_ps( _mm_add_ps( _mm_mul_ps(x, x), _mm_mul_ps(y, y) ), _mm_mul_ps(z, z) );
+}
+#endif
+
 /* SSE function which determines if any atom pair between two cells,
  * both with 8 atoms, is within distance sqrt(rl2).
+ * Not performance critical, so only uses plain SSE.
  */
 static gmx_bool subc_in_range_sse8(int na_c,
                                    int si, const real *x_i,
@@ -2430,7 +2453,7 @@ static void set_no_excls(nbnxn_excl_t *excl)
     for (t = 0; t < WARP_SIZE; t++)
     {
         /* Turn all interaction bits on */
-        excl->pair[t] = NBNXN_INT_MASK_ALL;
+        excl->pair[t] = NBNXN_INTERACTION_MASK_ALL;
     }
 }
 
@@ -2577,7 +2600,7 @@ static void print_nblist_statistics_simple(FILE *fp, const nbnxn_pairlist_t *nbl
 
         j = nbl->ci[i].cj_ind_start;
         while (j < nbl->ci[i].cj_ind_end &&
-               nbl->cj[j].excl != NBNXN_INT_MASK_ALL)
+               nbl->cj[j].excl != NBNXN_INTERACTION_MASK_ALL)
         {
             npexcl++;
             j++;
@@ -2720,43 +2743,44 @@ static void set_self_and_newton_excls_supersub(nbnxn_pairlist_t *nbl,
 /* Returns a diagonal or off-diagonal interaction mask for plain C lists */
 static unsigned int get_imask(gmx_bool rdiag, int ci, int cj)
 {
-    return (rdiag && ci == cj ? NBNXN_INT_MASK_DIAG : NBNXN_INT_MASK_ALL);
+    return (rdiag && ci == cj ? NBNXN_INTERACTION_MASK_DIAG : NBNXN_INTERACTION_MASK_ALL);
 }
 
-/* Returns a diagonal or off-diagonal interaction mask for SIMD128 lists */
-static unsigned int get_imask_simd128(gmx_bool rdiag, int ci, int cj)
+/* Returns a diagonal or off-diagonal interaction mask for cj-size=2 */
+static unsigned int get_imask_simd_j2(gmx_bool rdiag, int ci, int cj)
 {
-#ifndef GMX_DOUBLE /* cj-size = 4 */
-    return (rdiag && ci == cj ? NBNXN_INT_MASK_DIAG : NBNXN_INT_MASK_ALL);
-#else              /* cj-size = 2 */
-    return (rdiag && ci*2 == cj ? NBNXN_INT_MASK_DIAG_J2_0 :
-            (rdiag && ci*2+1 == cj ? NBNXN_INT_MASK_DIAG_J2_1 :
-             NBNXN_INT_MASK_ALL));
-#endif
+    return (rdiag && ci*2 == cj ? NBNXN_INTERACTION_MASK_DIAG_J2_0 :
+            (rdiag && ci*2+1 == cj ? NBNXN_INTERACTION_MASK_DIAG_J2_1 :
+             NBNXN_INTERACTION_MASK_ALL));
 }
 
-/* Returns a diagonal or off-diagonal interaction mask for SIMD256 lists */
-static unsigned int get_imask_simd256(gmx_bool rdiag, int ci, int cj)
+/* Returns a diagonal or off-diagonal interaction mask for cj-size=4 */
+static unsigned int get_imask_simd_j4(gmx_bool rdiag, int ci, int cj)
 {
-#ifndef GMX_DOUBLE /* cj-size = 8 */
-    return (rdiag && ci == cj*2 ? NBNXN_INT_MASK_DIAG_J8_0 :
-            (rdiag && ci == cj*2+1 ? NBNXN_INT_MASK_DIAG_J8_1 :
-             NBNXN_INT_MASK_ALL));
-#else              /* cj-size = 4 */
-    return (rdiag && ci == cj ? NBNXN_INT_MASK_DIAG : NBNXN_INT_MASK_ALL);
-#endif
+    return (rdiag && ci == cj ? NBNXN_INTERACTION_MASK_DIAG : NBNXN_INTERACTION_MASK_ALL);
+}
+
+/* Returns a diagonal or off-diagonal interaction mask for cj-size=8 */
+static unsigned int get_imask_simd_j8(gmx_bool rdiag, int ci, int cj)
+{
+    return (rdiag && ci == cj*2 ? NBNXN_INTERACTION_MASK_DIAG_J8_0 :
+            (rdiag && ci == cj*2+1 ? NBNXN_INTERACTION_MASK_DIAG_J8_1 :
+             NBNXN_INTERACTION_MASK_ALL));
 }
 
 #ifdef GMX_NBNXN_SIMD
-#if GMX_NBNXN_SIMD_BITWIDTH == 128
-#define get_imask_simd_4xn  get_imask_simd128
-#else
-#if GMX_NBNXN_SIMD_BITWIDTH == 256
-#define get_imask_simd_4xn  get_imask_simd256
-#define get_imask_simd_2xnn get_imask_simd128
-#else
-#error "unsupported GMX_NBNXN_SIMD_BITWIDTH"
+#if GMX_SIMD_WIDTH_HERE == 2
+#define get_imask_simd_4xn  get_imask_simd_j2
+#endif
+#if GMX_SIMD_WIDTH_HERE == 4
+#define get_imask_simd_4xn  get_imask_simd_j4
+#endif
+#if GMX_SIMD_WIDTH_HERE == 8
+#define get_imask_simd_4xn  get_imask_simd_j8
+#define get_imask_simd_2xnn get_imask_simd_j4
 #endif
+#if GMX_SIMD_WIDTH_HERE == 16
+#define get_imask_simd_2xnn get_imask_simd_j8
 #endif
 #endif
 
@@ -3413,18 +3437,18 @@ static void sort_cj_excl(nbnxn_cj_t *cj, int ncj,
     jnew = 0;
     for (j = 0; j < ncj; j++)
     {
-        if (cj[j].excl != NBNXN_INT_MASK_ALL)
+        if (cj[j].excl != NBNXN_INTERACTION_MASK_ALL)
         {
             work->cj[jnew++] = cj[j];
         }
     }
     /* Check if there are exclusions at all or not just the first entry */
     if (!((jnew == 0) ||
-          (jnew == 1 && cj[0].excl != NBNXN_INT_MASK_ALL)))
+          (jnew == 1 && cj[0].excl != NBNXN_INTERACTION_MASK_ALL)))
     {
         for (j = 0; j < ncj; j++)
         {
-            if (cj[j].excl == NBNXN_INT_MASK_ALL)
+            if (cj[j].excl == NBNXN_INTERACTION_MASK_ALL)
             {
                 work->cj[jnew++] = cj[j];
             }