Redesigned SIMD module and unit tests.
[alexxy/gromacs.git] / src / gromacs / mdlib / nbnxn_kernels / nbnxn_kernel_simd_utils_ref.h
index c7a6e9a6c0e9a8e93c5a99c8bb437bedd383e4c6..e7d38f2cfd5fb5c4c1e66487492614c75efd40b7 100644 (file)
 #ifndef _nbnxn_kernel_simd_utils_ref_h_
 #define _nbnxn_kernel_simd_utils_ref_h_
 
-typedef gmx_simd_ref_epi32      gmx_simd_ref_exclfilter;
+#
+#include "gromacs/simd/simd_math.h"
+
+typedef gmx_simd_int32_t        gmx_simd_ref_exclfilter;
 typedef gmx_simd_ref_exclfilter gmx_exclfilter;
 static const int filter_stride = GMX_SIMD_INT32_WIDTH/GMX_SIMD_REAL_WIDTH;
 
@@ -55,13 +58,13 @@ static const int nbfp_stride = 4;
 /* float/double SIMD register type */
 typedef struct {
     real r[4];
-} gmx_mm_pr4;
+} gmx_simd4_real_t;
 
-static gmx_inline gmx_mm_pr4
-gmx_load_pr4(const real *r)
+static gmx_inline gmx_simd4_real_t
+gmx_simd4_load_r(const real *r)
 {
-    gmx_mm_pr4 a;
-    int        i;
+    gmx_simd4_real_t a;
+    int              i;
 
     for (i = 0; i < 4; i++)
     {
@@ -72,10 +75,10 @@ gmx_load_pr4(const real *r)
 }
 
 static gmx_inline void
-gmx_store_pr4(real *dest, gmx_mm_pr4 src)
+gmx_simd4_store_r(real *dest, gmx_simd4_real_t src)
 {
-    gmx_mm_pr4 a;
-    int        i;
+    gmx_simd4_real_t a;
+    int              i;
 
     for (i = 0; i < 4; i++)
     {
@@ -83,11 +86,11 @@ gmx_store_pr4(real *dest, gmx_mm_pr4 src)
     }
 }
 
-static gmx_inline gmx_mm_pr4
-gmx_add_pr4(gmx_mm_pr4 a, gmx_mm_pr4 b)
+static gmx_inline gmx_simd4_real_t
+gmx_simd4_add_r(gmx_simd4_real_t a, gmx_simd4_real_t b)
 {
-    gmx_mm_pr4 c;
-    int        i;
+    gmx_simd4_real_t c;
+    int              i;
 
     for (i = 0; i < 4; i++)
     {
@@ -96,6 +99,13 @@ gmx_add_pr4(gmx_mm_pr4 a, gmx_mm_pr4 b)
 
     return c;
 }
+
+static gmx_inline real
+gmx_simd4_reduce_r(gmx_simd4_real_t a)
+{
+    return a.r[0] + a.r[1] + a.r[2] + a.r[3];
+}
+
 #endif
 
 
@@ -137,7 +147,7 @@ gmx_set1_hpr(gmx_mm_hpr *a, real b)
 
 /* Load one real at b and one real at b+1 into halves of a, respectively */
 static gmx_inline void
-gmx_load1p1_pr(gmx_simd_ref_pr *a, const real *b)
+gmx_load1p1_pr(gmx_simd_real_t *a, const real *b)
 {
     int i;
 
@@ -150,7 +160,7 @@ gmx_load1p1_pr(gmx_simd_ref_pr *a, const real *b)
 
 /* Load reals at half-width aligned pointer b into two halves of a */
 static gmx_inline void
-gmx_loaddh_pr(gmx_simd_ref_pr *a, const real *b)
+gmx_loaddh_pr(gmx_simd_real_t *a, const real *b)
 {
     int i;
 
@@ -203,7 +213,7 @@ gmx_sub_hpr(gmx_mm_hpr a, gmx_mm_hpr b)
 
 /* Sum over 4 half SIMD registers */
 static gmx_inline gmx_mm_hpr
-gmx_sum4_hpr(gmx_simd_ref_pr a, gmx_simd_ref_pr b)
+gmx_sum4_hpr(gmx_simd_real_t a, gmx_simd_real_t b)
 {
     gmx_mm_hpr c;
     int        i;
@@ -222,11 +232,11 @@ gmx_sum4_hpr(gmx_simd_ref_pr a, gmx_simd_ref_pr b)
 
 #ifdef GMX_NBNXN_SIMD_2XNN
 /* Sum the elements of halfs of each input register and store sums in out */
-static gmx_inline gmx_mm_pr4
-gmx_mm_transpose_sum4h_pr(gmx_simd_ref_pr a, gmx_simd_ref_pr b)
+static gmx_inline gmx_simd4_real_t
+gmx_mm_transpose_sum4h_pr(gmx_simd_real_t a, gmx_simd_real_t b)
 {
-    gmx_mm_pr4 sum;
-    int        i;
+    gmx_simd4_real_t sum;
+    int              i;
 
     sum.r[0] = 0;
     sum.r[1] = 0;
@@ -246,7 +256,7 @@ gmx_mm_transpose_sum4h_pr(gmx_simd_ref_pr a, gmx_simd_ref_pr b)
 #endif
 
 static gmx_inline void
-gmx_pr_to_2hpr(gmx_simd_ref_pr a, gmx_mm_hpr *b, gmx_mm_hpr *c)
+gmx_pr_to_2hpr(gmx_simd_real_t a, gmx_mm_hpr *b, gmx_mm_hpr *c)
 {
     int i;
 
@@ -257,7 +267,7 @@ gmx_pr_to_2hpr(gmx_simd_ref_pr a, gmx_mm_hpr *b, gmx_mm_hpr *c)
     }
 }
 static gmx_inline void
-gmx_2hpr_to_pr(gmx_mm_hpr a, gmx_mm_hpr b, gmx_simd_ref_pr *c)
+gmx_2hpr_to_pr(gmx_mm_hpr a, gmx_mm_hpr b, gmx_simd_real_t *c)
 {
     int i;
 
@@ -273,16 +283,16 @@ gmx_2hpr_to_pr(gmx_mm_hpr a, gmx_mm_hpr b, gmx_simd_ref_pr *c)
 
 #ifndef TAB_FDV0
 static gmx_inline void
-load_table_f(const real *tab_coul_F, gmx_simd_ref_epi32 ti_S,
+load_table_f(const real *tab_coul_F, gmx_simd_int32_t ti_S,
              int gmx_unused *ti,
-             gmx_simd_ref_pr *ctab0_S, gmx_simd_ref_pr *ctab1_S)
+             gmx_simd_real_t *ctab0_S, gmx_simd_real_t *ctab1_S)
 {
     int i;
 
     for (i = 0; i < GMX_SIMD_REAL_WIDTH; i++)
     {
-        ctab0_S->r[i] = tab_coul_F[ti_S.r[i]];
-        ctab1_S->r[i] = tab_coul_F[ti_S.r[i]+1];
+        ctab0_S->r[i] = tab_coul_F[ti_S.i[i]];
+        ctab1_S->r[i] = tab_coul_F[ti_S.i[i]+1];
     }
 
     *ctab1_S  = gmx_simd_sub_r(*ctab1_S, *ctab0_S);
@@ -290,9 +300,9 @@ load_table_f(const real *tab_coul_F, gmx_simd_ref_epi32 ti_S,
 
 static gmx_inline void
 load_table_f_v(const real *tab_coul_F, const real *tab_coul_V,
-               gmx_simd_ref_epi32 ti_S, int *ti,
-               gmx_simd_ref_pr *ctab0_S, gmx_simd_ref_pr *ctab1_S,
-               gmx_simd_ref_pr *ctabv_S)
+               gmx_simd_int32_t ti_S, int *ti,
+               gmx_simd_real_t *ctab0_S, gmx_simd_real_t *ctab1_S,
+               gmx_simd_real_t *ctabv_S)
 {
     int i;
 
@@ -300,30 +310,30 @@ load_table_f_v(const real *tab_coul_F, const real *tab_coul_V,
 
     for (i = 0; i < GMX_SIMD_REAL_WIDTH; i++)
     {
-        ctabv_S->r[i] = tab_coul_V[ti_S.r[i]];
+        ctabv_S->r[i] = tab_coul_V[ti_S.i[i]];
     }
 }
 #endif
 
 #ifdef TAB_FDV0
 static gmx_inline void
-load_table_f(const real *tab_coul_FDV0, gmx_simd_ref_epi32 ti_S, int *ti,
-             gmx_simd_ref_pr *ctab0_S, gmx_simd_ref_pr *ctab1_S)
+load_table_f(const real *tab_coul_FDV0, gmx_simd_int32_t ti_S, int *ti,
+             gmx_simd_real_t *ctab0_S, gmx_simd_real_t *ctab1_S)
 {
     int i;
 
     for (i = 0; i < GMX_SIMD_REAL_WIDTH; i++)
     {
-        ctab0_S->r[i] = tab_coul_FDV0[ti_S.r[i]*4];
-        ctab1_S->r[i] = tab_coul_FDV0[ti_S.r[i]*4+1];
+        ctab0_S->r[i] = tab_coul_FDV0[ti_S.i[i]*4];
+        ctab1_S->r[i] = tab_coul_FDV0[ti_S.i[i]*4+1];
     }
 }
 
 static gmx_inline void
 load_table_f_v(const real *tab_coul_FDV0,
-               gmx_simd_ref_epi32 ti_S, int *ti,
-               gmx_simd_ref_pr *ctab0_S, gmx_simd_ref_pr *ctab1_S,
-               gmx_simd_ref_pr *ctabv_S)
+               gmx_simd_int32_t ti_S, int *ti,
+               gmx_simd_real_t *ctab0_S, gmx_simd_real_t *ctab1_S,
+               gmx_simd_real_t *ctabv_S)
 {
     int i;
 
@@ -331,7 +341,7 @@ load_table_f_v(const real *tab_coul_FDV0,
 
     for (i = 0; i < GMX_SIMD_REAL_WIDTH; i++)
     {
-        ctabv_S->r[i] = tab_coul_FDV0[ti_S.r[i]*4+2];
+        ctabv_S->r[i] = tab_coul_FDV0[ti_S.i[i]*4+2];
     }
 }
 #endif
@@ -340,10 +350,10 @@ load_table_f_v(const real *tab_coul_FDV0,
  * Note that 4/8-way SIMD requires gmx_mm_transpose_sum4_pr instead.
  */
 #if GMX_SIMD_REAL_WIDTH == 2
-static gmx_inline gmx_simd_ref_pr
-gmx_mm_transpose_sum2_pr(gmx_simd_ref_pr in0, gmx_simd_ref_pr in1)
+static gmx_inline gmx_simd_real_t
+gmx_mm_transpose_sum2_pr(gmx_simd_real_t in0, gmx_simd_real_t in1)
 {
-    gmx_simd_ref_pr sum;
+    gmx_simd_real_t sum;
 
     sum.r[0] = in0.r[0] + in0.r[1];
     sum.r[1] = in1.r[0] + in1.r[1];
@@ -354,19 +364,19 @@ gmx_mm_transpose_sum2_pr(gmx_simd_ref_pr in0, gmx_simd_ref_pr in1)
 
 #if GMX_SIMD_REAL_WIDTH >= 4
 #if GMX_SIMD_REAL_WIDTH == 4
-static gmx_inline gmx_simd_ref_pr
+static gmx_inline gmx_simd_real_t
 #else
-static gmx_inline gmx_mm_pr4
+static gmx_inline gmx_simd4_real_t
 #endif
-gmx_mm_transpose_sum4_pr(gmx_simd_ref_pr in0, gmx_simd_ref_pr in1,
-                         gmx_simd_ref_pr in2, gmx_simd_ref_pr in3)
+gmx_mm_transpose_sum4_pr(gmx_simd_real_t in0, gmx_simd_real_t in1,
+                         gmx_simd_real_t in2, gmx_simd_real_t in3)
 {
 #if GMX_SIMD_REAL_WIDTH == 4
-    gmx_simd_ref_pr sum;
+    gmx_simd_real_t  sum;
 #else
-    gmx_mm_pr4      sum;
+    gmx_simd4_real_t sum;
 #endif
-    int             i;
+    int              i;
 
     sum.r[0] = 0;
     sum.r[1] = 0;
@@ -392,8 +402,8 @@ gmx_mm_transpose_sum4_pr(gmx_simd_ref_pr in0, gmx_simd_ref_pr in1,
  * For this reference code we just use a plain-C sqrt.
  */
 static gmx_inline void
-gmx_mm_invsqrt2_pd(gmx_simd_ref_pr in0, gmx_simd_ref_pr in1,
-                   gmx_simd_ref_pr *out0, gmx_simd_ref_pr *out1)
+gmx_mm_invsqrt2_pd(gmx_simd_real_t in0, gmx_simd_real_t in1,
+                   gmx_simd_real_t *out0, gmx_simd_real_t *out1)
 {
     *out0 = gmx_simd_invsqrt_r(in0);
     *out1 = gmx_simd_invsqrt_r(in1);
@@ -402,7 +412,7 @@ gmx_mm_invsqrt2_pd(gmx_simd_ref_pr in0, gmx_simd_ref_pr in1,
 
 static gmx_inline void
 load_lj_pair_params(const real *nbfp, const int *type, int aj,
-                    gmx_simd_ref_pr *c6_S, gmx_simd_ref_pr *c12_S)
+                    gmx_simd_real_t *c6_S, gmx_simd_real_t *c12_S)
 {
     int i;
 
@@ -417,7 +427,7 @@ load_lj_pair_params(const real *nbfp, const int *type, int aj,
 static gmx_inline void
 load_lj_pair_params2(const real *nbfp0, const real *nbfp1,
                      const int *type, int aj,
-                     gmx_simd_ref_pr *c6_S, gmx_simd_ref_pr *c12_S)
+                     gmx_simd_real_t *c6_S, gmx_simd_real_t *c12_S)
 {
     int i;
 
@@ -445,9 +455,9 @@ gmx_simd_ref_load1_exclfilter(int src)
     gmx_simd_ref_exclfilter a;
     int                     i;
 
-    for (i = 0; i < GMX_SIMD_REF_WIDTH; i++)
+    for (i = 0; i < GMX_SIMD_REAL_WIDTH; i++)
     {
-        a.r[i] = src;
+        a.i[i] = src;
     }
 
     return a;
@@ -459,9 +469,9 @@ gmx_simd_ref_load_exclusion_filter(const int *src)
     gmx_simd_ref_exclfilter a;
     int                     i;
 
-    for (i = 0; i < GMX_SIMD_REF_WIDTH; i++)
+    for (i = 0; i < GMX_SIMD_REAL_WIDTH; i++)
     {
-        a.r[i] = src[i];
+        a.i[i] = src[i];
     }
 
     return a;
@@ -478,15 +488,15 @@ gmx_simd_ref_load_exclusion_filter(const int *src)
  * If the same bit is set in both input masks, return TRUE, else
  * FALSE. This function is only called with a single bit set in b.
  */
-static gmx_inline gmx_simd_ref_pb
+static gmx_inline gmx_simd_bool_t
 gmx_simd_ref_checkbitmask_pb(gmx_simd_ref_exclfilter a, gmx_simd_ref_exclfilter b)
 {
-    gmx_simd_ref_pb c;
+    gmx_simd_bool_t c;
     int             i;
 
-    for (i = 0; i < GMX_SIMD_REF_WIDTH; i++)
+    for (i = 0; i < GMX_SIMD_REAL_WIDTH; i++)
     {
-        c.r[i] = ((a.r[i] & b.r[i]) != 0);
+        c.b[i] = ((a.i[i] & b.i[i]) != 0);
     }
 
     return c;