AVX512 transposeScatterIncr/DecrU with load/store
author    Roland Schulz <roland.schulz@intel.com>
Wed, 9 Mar 2016 02:21:35 +0000 (18:21 -0800)
committer Mark Abraham <mark.j.abraham@gmail.com>
Wed, 6 Apr 2016 13:52:30 +0000 (15:52 +0200)
Also remove one more mask
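
This replaces the scalar spill-and-loop fallback in transposeScatterIncrU/DecrU
(float and double) with an in-register transpose followed by a vector load,
add/sub and store per particle, and replaces the masked store in storeDualHsimd
with a plain 256-bit store of the extracted upper half. The new paths have to
keep the semantics of the removed scalar loop (minimal sketch; rdata0..rdata2
are the lanes of v0..v2 spilled to memory):

    for (int i = 0; i < GMX_SIMD_DOUBLE_WIDTH; i++)
    {
        base[align*offset[i] + 0] += rdata0[i]; // x of particle offset[i]
        base[align*offset[i] + 1] += rdata1[i]; // y
        base[align*offset[i] + 2] += rdata2[i]; // z
    }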

Change-Id: I0d27eb42eead92d2f50725b2f3831af7e9ee229e

src/gromacs/simd/impl_x86_avx_512/impl_x86_avx_512_util_double.h
src/gromacs/simd/impl_x86_avx_512/impl_x86_avx_512_util_float.h

src/gromacs/simd/impl_x86_avx_512/impl_x86_avx_512_util_double.h
index a0efbb83d46778a31596c64b666fb8f33e4e13cc..df12265797f530b9b3d5c72a451e9849f3320a5e 100644 (file)
@@ -240,19 +240,49 @@ transposeScatterIncrU(double *            base,
                       SimdDouble          v1,
                       SimdDouble          v2)
 {
-    GMX_ALIGNED(double, GMX_SIMD_DOUBLE_WIDTH)  rdata0[GMX_SIMD_DOUBLE_WIDTH];
-    GMX_ALIGNED(double, GMX_SIMD_DOUBLE_WIDTH)  rdata1[GMX_SIMD_DOUBLE_WIDTH];
-    GMX_ALIGNED(double, GMX_SIMD_DOUBLE_WIDTH)  rdata2[GMX_SIMD_DOUBLE_WIDTH];
-
-    store(rdata0, v0);
-    store(rdata1, v1);
-    store(rdata2, v2);
-
-    for (int i = 0; i < GMX_SIMD_DOUBLE_WIDTH; i++)
+    __m512d t[4], t5, t6, t7, t8;
+    GMX_ALIGNED(std::int64_t, 8)    o[8];
+    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(_mm256_load_si256((const __m256i*)(offset  )), _mm256_set1_epi32(align))));
+    t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
+    t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
+    t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
+    t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
+    t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E); // x0 y0 z0  0 | x4 y4 z4 0
+    t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E); // x1 y1 z1  0 | x5 y5 z5 0
+    t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E); // x2 y2 z2  0 | x6 y6 z6 0
+    t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E); // x3 y3 z3  0 | x7 y7 z7 0
+    if (align < 4)
+    {
+        for (int i = 0; i < 4; i++)
+        {
+            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
+                                          _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i]))));
+            _mm512_mask_storeu_pd(base + o[4 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
+                                          _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1))));
+        }
+    }
+    else
     {
-        base[ align * offset[i] + 0] += rdata0[i];
-        base[ align * offset[i] + 1] += rdata1[i];
-        base[ align * offset[i] + 2] += rdata2[i];
+        if (align % 4 == 0)
+        {
+            for (int i = 0; i < 4; i++)
+            {
+                _mm256_store_pd(base + o[0 + i],
+                                _mm256_add_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
+                _mm256_store_pd(base + o[4 + i],
+                                _mm256_add_pd(_mm256_load_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
+            }
+        }
+        else
+        {
+            for (int i = 0; i < 4; i++)
+            {
+                _mm256_storeu_pd(base + o[0 + i],
+                                 _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
+                _mm256_storeu_pd(base + o[4 + i],
+                                 _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
+            }
+        }
     }
 }
 
@@ -264,19 +294,49 @@ transposeScatterDecrU(double *            base,
                       SimdDouble          v1,
                       SimdDouble          v2)
 {
-    GMX_ALIGNED(double, GMX_SIMD_DOUBLE_WIDTH)  rdata0[GMX_SIMD_DOUBLE_WIDTH];
-    GMX_ALIGNED(double, GMX_SIMD_DOUBLE_WIDTH)  rdata1[GMX_SIMD_DOUBLE_WIDTH];
-    GMX_ALIGNED(double, GMX_SIMD_DOUBLE_WIDTH)  rdata2[GMX_SIMD_DOUBLE_WIDTH];
-
-    store(rdata0, v0);
-    store(rdata1, v1);
-    store(rdata2, v2);
-
-    for (int i = 0; i < GMX_SIMD_DOUBLE_WIDTH; i++)
+    __m512d t[4], t5, t6, t7, t8;
+    GMX_ALIGNED(std::int64_t, 8)    o[8];
+    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(_mm256_load_si256((const __m256i*)(offset  )), _mm256_set1_epi32(align))));
+    t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
+    t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
+    t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
+    t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
+    t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E); // x0 y0 z0  0 | x4 y4 z4 0
+    t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E); // x2 y2 z2  0 | x6 y6 z6 0
+    t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E); // x1 y1 z1  0 | x5 y5 z5 0
+    t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E); // x3 y3 z3  0 | x7 y7 z7 0
+    if (align < 4)
+    {
+        for (int i = 0; i < 4; i++)
+        {
+            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
+                                          _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i]))));
+            _mm512_mask_storeu_pd(base + o[4 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
+                                          _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1))));
+        }
+    }
+    else
     {
-        base[ align * offset[i] + 0] -= rdata0[i];
-        base[ align * offset[i] + 1] -= rdata1[i];
-        base[ align * offset[i] + 2] -= rdata2[i];
+        if (align % 4 == 0)
+        {
+            for (int i = 0; i < 4; i++)
+            {
+                _mm256_store_pd(base + o[0 + i],
+                                _mm256_sub_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
+                _mm256_store_pd(base + o[4 + i],
+                                _mm256_sub_pd(_mm256_load_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
+            }
+        }
+        else
+        {
+            for (int i = 0; i < 4; i++)
+            {
+                _mm256_storeu_pd(base + o[0 + i],
+                                 _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
+                _mm256_storeu_pd(base + o[4 + i],
+                                 _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
+            }
+        }
     }
 }
 
@@ -369,7 +429,7 @@ storeDualHsimd(double *     m0,
     assert(std::size_t(m1) % 32 == 0);
 
     _mm256_store_pd(m0, _mm512_castpd512_pd256(a.simdInternal_));
-    _mm512_mask_storeu_pd(m1-4, avx512Int2Mask(0xF0), a.simdInternal_);
+    _mm256_store_pd(m1, _mm512_extractf64x4_pd(a.simdInternal_, 1));
 }
 
 static inline void gmx_simdcall
src/gromacs/simd/impl_x86_avx_512/impl_x86_avx_512_util_float.h
index ee94c3f9bcbeaba9995a4bc14b89da50567e9072..e3acc0f37f235717eaa17e881843e9dc337ac4ed 100644 (file)
@@ -242,19 +242,67 @@ transposeScatterIncrU(float *              base,
                       SimdFloat            v1,
                       SimdFloat            v2)
 {
-    GMX_ALIGNED(float, GMX_SIMD_FLOAT_WIDTH)  rdata0[GMX_SIMD_FLOAT_WIDTH];
-    GMX_ALIGNED(float, GMX_SIMD_FLOAT_WIDTH)  rdata1[GMX_SIMD_FLOAT_WIDTH];
-    GMX_ALIGNED(float, GMX_SIMD_FLOAT_WIDTH)  rdata2[GMX_SIMD_FLOAT_WIDTH];
-
-    store(rdata0, v0);
-    store(rdata1, v1);
-    store(rdata2, v2);
-
-    for (int i = 0; i < GMX_SIMD_FLOAT_WIDTH; i++)
+    __m512 t[4], t5, t6, t7, t8;
+    int    i;
+    GMX_ALIGNED(std::int32_t, 16)   o[16];
+    _mm512_store_epi32(o, _mm512_mullo_epi32(_mm512_load_epi32(offset), _mm512_set1_epi32(align)));
+    if (align < 4)
+    {
+        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
+        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
+        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0)); // x0 y0 z0 z0 | x4 y4 z4 z4
+        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2)); // x1 y1 z1 z1 | x5 y5 z5 z5
+        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0)); // x2 y2 z2 z2 | x6 y6 z6 z6
+        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2)); // x3 y3 z3 z3 | x7 y7 z7 z7
+        for (i = 0; i < 4; i++)
+        {
+            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7), _mm512_castps128_ps512(
+                                          _mm_add_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i]))));
+            _mm512_mask_storeu_ps(base + o[ 4 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
+                                          _mm_add_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1))));
+            _mm512_mask_storeu_ps(base + o[ 8 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
+                                          _mm_add_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2))));
+            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
+                                          _mm_add_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3))));
+        }
+    }
+    else
     {
-        base[ align * offset[i] + 0] += rdata0[i];
-        base[ align * offset[i] + 1] += rdata1[i];
-        base[ align * offset[i] + 2] += rdata2[i];
+        // One could use shuffle here too if it were OK to overwrite the padding elements used for alignment
+        t5    = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
+        t6    = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
+        t7    = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
+        t8    = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
+        t[0]  = _mm512_unpacklo_ps(t5, t7);                             // x0 y0 z0  0 | x4 y4 z4 0
+        t[1]  = _mm512_unpackhi_ps(t5, t7);                             // x1 y1 z1  0 | x5 y5 z5 0
+        t[2]  = _mm512_unpacklo_ps(t6, t8);                             // x2 y2 z2  0 | x6 y6 z6 0
+        t[3]  = _mm512_unpackhi_ps(t6, t8);                             // x3 y3 z3  0 | x7 y7 z7 0
+        if (align % 4 == 0)
+        {
+            for (i = 0; i < 4; i++)
+            {
+                _mm_store_ps(base + o[i], _mm_add_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
+                _mm_store_ps(base + o[ 4 + i],
+                             _mm_add_ps(_mm_load_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
+                _mm_store_ps(base + o[ 8 + i],
+                             _mm_add_ps(_mm_load_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
+                _mm_store_ps(base + o[12 + i],
+                             _mm_add_ps(_mm_load_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
+            }
+        }
+        else
+        {
+            for (i = 0; i < 4; i++)
+            {
+                _mm_storeu_ps(base + o[i], _mm_add_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
+                _mm_storeu_ps(base + o[ 4 + i],
+                              _mm_add_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
+                _mm_storeu_ps(base + o[ 8 + i],
+                              _mm_add_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
+                _mm_storeu_ps(base + o[12 + i],
+                              _mm_add_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
+            }
+        }
     }
 }
 
@@ -266,19 +314,67 @@ transposeScatterDecrU(float *              base,
                       SimdFloat            v1,
                       SimdFloat            v2)
 {
-    GMX_ALIGNED(float, GMX_SIMD_FLOAT_WIDTH)  rdata0[GMX_SIMD_FLOAT_WIDTH];
-    GMX_ALIGNED(float, GMX_SIMD_FLOAT_WIDTH)  rdata1[GMX_SIMD_FLOAT_WIDTH];
-    GMX_ALIGNED(float, GMX_SIMD_FLOAT_WIDTH)  rdata2[GMX_SIMD_FLOAT_WIDTH];
-
-    store(rdata0, v0);
-    store(rdata1, v1);
-    store(rdata2, v2);
-
-    for (int i = 0; i < GMX_SIMD_FLOAT_WIDTH; i++)
+    __m512 t[4], t5, t6, t7, t8;
+    int    i;
+    GMX_ALIGNED(std::int32_t, 16)   o[16];
+    _mm512_store_epi32(o, _mm512_mullo_epi32(_mm512_load_epi32(offset), _mm512_set1_epi32(align)));
+    if (align < 4)
+    {
+        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
+        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
+        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0)); // x0 y0 z0 z0 | x4 y4 z4 z4
+        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2)); // x1 y1 z1 z1 | x5 y5 z5 z5
+        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0)); // x2 y2 z2 z2 | x6 y6 z6 z6
+        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2)); // x3 y3 z3 z3 | x7 y7 z7 z7
+        for (i = 0; i < 4; i++)
+        {
+            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7), _mm512_castps128_ps512(
+                                          _mm_sub_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i]))));
+            _mm512_mask_storeu_ps(base + o[ 4 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
+                                          _mm_sub_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1))));
+            _mm512_mask_storeu_ps(base + o[ 8 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
+                                          _mm_sub_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2))));
+            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
+                                          _mm_sub_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3))));
+        }
+    }
+    else
     {
-        base[ align * offset[i] + 0] -= rdata0[i];
-        base[ align * offset[i] + 1] -= rdata1[i];
-        base[ align * offset[i] + 2] -= rdata2[i];
+        // One could use shuffle here too if it were OK to overwrite the padding elements used for alignment
+        t5    = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
+        t6    = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
+        t7    = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
+        t8    = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
+        t[0]  = _mm512_unpacklo_ps(t5, t7);                             // x0 y0 z0  0 | x4 y4 z4 0
+        t[1]  = _mm512_unpackhi_ps(t5, t7);                             // x1 y1 z1  0 | x5 y5 z5 0
+        t[2]  = _mm512_unpacklo_ps(t6, t8);                             // x2 y2 z2  0 | x6 y6 z6 0
+        t[3]  = _mm512_unpackhi_ps(t6, t8);                             // x3 y3 z3  0 | x7 y7 z7 0
+        if (align % 4 == 0)
+        {
+            for (i = 0; i < 4; i++)
+            {
+                _mm_store_ps(base + o[i], _mm_sub_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
+                _mm_store_ps(base + o[ 4 + i],
+                             _mm_sub_ps(_mm_load_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
+                _mm_store_ps(base + o[ 8 + i],
+                             _mm_sub_ps(_mm_load_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
+                _mm_store_ps(base + o[12 + i],
+                             _mm_sub_ps(_mm_load_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
+            }
+        }
+        else
+        {
+            for (i = 0; i < 4; i++)
+            {
+                _mm_storeu_ps(base + o[i], _mm_sub_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
+                _mm_storeu_ps(base + o[ 4 + i],
+                              _mm_sub_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
+                _mm_storeu_ps(base + o[ 8 + i],
+                              _mm_sub_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
+                _mm_storeu_ps(base + o[12 + i],
+                              _mm_sub_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
+            }
+        }
     }
 }
 
@@ -372,7 +468,7 @@ storeDualHsimd(float *     m0,
     assert(std::size_t(m1) % 32 == 0);
 
     _mm256_store_ps(m0, _mm512_castps512_ps256(a.simdInternal_));
-    _mm512_mask_storeu_ps(m1-8, avx512Int2Mask(0xFF00), a.simdInternal_);
+    _mm256_store_pd(reinterpret_cast<double*>(m1), _mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1)); // upper 256 bits; AVX-512F has no _mm512_extractf32x8_ps, hence the pd cast
 }
 
 static inline void gmx_simdcall