/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017,2018, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_double.h"

namespace gmx
{

static const int c_simdBestPairAlignmentDouble = 2;

namespace
{
// Multiply function optimized for powers of 2, for which it is done by
// shifting. Currently up to 8 is accelerated. Could be accelerated for any
// number with a constexpr log2 function.
template <int n>
SimdDInt32 fastMultiply(SimdDInt32 x)
{
    if (n == 2)
    {
        return _mm256_slli_epi32(x.simdInternal_, 1);
    }
    else if (n == 4)
    {
        return _mm256_slli_epi32(x.simdInternal_, 2);
    }
    else if (n == 8)
    {
        return _mm256_slli_epi32(x.simdInternal_, 3);
    }
    else
    {
        // Generic multiply for the remaining (non-power-of-2) alignments
        return _mm256_mullo_epi32(_mm256_set1_epi32(n), x.simdInternal_);
    }
}
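
// Example: fastMultiply<4>(x) maps offsets {0, 1, 2, ...} to {0, 4, 8, ...}
// with a single shift instead of a full 32-bit multiply.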

template <int align>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const double *, SimdDInt32)
{
    // Nothing to do. Termination of recursion.
}
}    // namespace

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const double * base, SimdDInt32 offset, SimdDouble *v, Targs... Fargs)
{
    if (align > 1)
    {
        offset = fastMultiply<align>(offset);
    }
    constexpr size_t scale = sizeof(double);
    v->simdInternal_ = _mm512_i32gather_pd(offset.simdInternal_, base, scale);
    // The offsets are already scaled, so recurse with align 1 for the rest
    gatherLoadBySimdIntTranspose<1>(base + 1, offset, Fargs ...);
}
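
// Usage sketch (hypothetical call; x/y/z are illustrative names):
//
//     SimdDouble x, y, z;
//     gatherLoadBySimdIntTranspose<4>(base, offsets, &x, &y, &z);
//
// scales the offsets once, gathers base[4*o] into x, then recurses with
// align 1 on base+1 and base+2 to fill y and z.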

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadUBySimdIntTranspose(const double *base, SimdDInt32 offset, Targs... Fargs)
{
    // AVX-512 gathers have no alignment requirement, so delegate
    gatherLoadBySimdIntTranspose<align>(base, offset, Fargs ...);
}

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadTranspose(const double *base, const std::int32_t offset[], Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), Fargs ...);
}

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadUTranspose(const double *base, const std::int32_t offset[], Targs... Fargs)
{
    // Aligned and unaligned gathers are identical here, so delegate
    gatherLoadTranspose<align>(base, offset, Fargs ...);
}

template <int align>
static inline void gmx_simdcall
transposeScatterStoreU(double *            base,
                       const std::int32_t  offset[],
                       SimdDouble          v0,
                       SimdDouble          v1,
                       SimdDouble          v2)
{
    SimdDInt32 simdoffset = simdLoad(offset, SimdDInt32Tag());

    if (align > 1)
    {
        simdoffset = fastMultiply<align>(simdoffset);
    }
    constexpr size_t scale = sizeof(double);
    _mm512_i32scatter_pd(base,       simdoffset.simdInternal_, v0.simdInternal_, scale);
    _mm512_i32scatter_pd(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
    _mm512_i32scatter_pd(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
}
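
// Usage sketch (hypothetical; f and indices are illustrative names): the
// inverse of the gathers above, writing lane j of v0/v1/v2 to
// base[align*offset[j] + 0/1/2]:
//
//     transposeScatterStoreU<3>(f, indices, fx, fy, fz);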

template <int align>
static inline void gmx_simdcall
transposeScatterIncrU(double *            base,
                      const std::int32_t  offset[],
                      SimdDouble          v0,
                      SimdDouble          v1,
                      SimdDouble          v2)
{
    __m512d t[4], t5, t6, t7, t8;
    alignas(GMX_SIMD_ALIGNMENT) std::int64_t o[8];
    // TODO: should use fastMultiply
    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(_mm256_load_si256(reinterpret_cast<const __m256i*>(offset)), _mm256_set1_epi32(align))));
    t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
    t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
    t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
    t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
    t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
    t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
    t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
    t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
    if (align < 4)
    {
        // Rows may overlap: update only 3 elements per row with masked stores
        for (int i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i]))));
            _mm512_mask_storeu_pd(base + o[4 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1))));
        }
    }
    else if (align % 4 == 0)
    {
        // Rows are 32-byte aligned: full-width aligned loads/stores
        for (int i = 0; i < 4; i++)
        {
            _mm256_store_pd(base + o[0 + i],
                            _mm256_add_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
            _mm256_store_pd(base + o[4 + i],
                            _mm256_add_pd(_mm256_load_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
        }
    }
    else
    {
        for (int i = 0; i < 4; i++)
        {
            _mm256_storeu_pd(base + o[0 + i],
                             _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
            _mm256_storeu_pd(base + o[4 + i],
                             _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
        }
    }
}
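
// Note on the transpose trick used in transposeScatterIncrU above and
// transposeScatterDecrU below: after the unpack/permute sequence, t[i] holds
// the contiguous triple {v0[i], v1[i], v2[i], 0} in its low 256 bits and
// {v0[i+4], v1[i+4], v2[i+4], 0} in its high 256 bits, so each loop
// iteration updates two rows of base.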

template <int align>
static inline void gmx_simdcall
transposeScatterDecrU(double *            base,
                      const std::int32_t  offset[],
                      SimdDouble          v0,
                      SimdDouble          v1,
                      SimdDouble          v2)
{
    __m512d t[4], t5, t6, t7, t8;
    alignas(GMX_SIMD_ALIGNMENT) std::int64_t o[8];
    // TODO: should use fastMultiply
    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(_mm256_load_si256(reinterpret_cast<const __m256i*>(offset)), _mm256_set1_epi32(align))));
    t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
    t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
    t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
    t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
    t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
    t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
    t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
    t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
    if (align < 4)
    {
        // Rows may overlap: update only 3 elements per row with masked stores
        for (int i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i]))));
            _mm512_mask_storeu_pd(base + o[4 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1))));
        }
    }
    else if (align % 4 == 0)
    {
        // Rows are 32-byte aligned: full-width aligned loads/stores
        for (int i = 0; i < 4; i++)
        {
            _mm256_store_pd(base + o[0 + i],
                            _mm256_sub_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
            _mm256_store_pd(base + o[4 + i],
                            _mm256_sub_pd(_mm256_load_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
        }
    }
    else
    {
        for (int i = 0; i < 4; i++)
        {
            _mm256_storeu_pd(base + o[0 + i],
                             _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
            _mm256_storeu_pd(base + o[4 + i],
                             _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
        }
    }
}

static inline void gmx_simdcall
expandScalarsToTriplets(SimdDouble   scalar,
                        SimdDouble * triplets0,
                        SimdDouble * triplets1,
                        SimdDouble * triplets2)
{
    triplets0->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(5, 4, 5, 4, 3, 2, 3, 2, 3, 2, 1, 0, 1, 0, 1, 0),
                                                                            _mm512_castpd_si512(scalar.simdInternal_)));
    triplets1->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(11, 10, 9, 8, 9, 8, 9, 8, 7, 6, 7, 6, 7, 6, 5, 4),
                                                                            _mm512_castpd_si512(scalar.simdInternal_)));
    triplets2->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(15, 14, 15, 14, 15, 14, 13, 12, 13, 12, 13, 12, 11, 10, 11, 10),
                                                                            _mm512_castpd_si512(scalar.simdInternal_)));
}
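
// Example: for scalar = {s0, s1, s2, s3, s4, s5, s6, s7} the outputs are
//   triplets0 = {s0, s0, s0, s1, s1, s1, s2, s2}
//   triplets1 = {s2, s3, s3, s3, s4, s4, s4, s5}
//   triplets2 = {s5, s5, s6, s6, s6, s7, s7, s7}
// i.e. each scalar is repeated three times across the three output vectors.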

static inline double gmx_simdcall
reduceIncr4ReturnSum(double *   m,
                     SimdDouble v0,
                     SimdDouble v1,
                     SimdDouble v2,
                     SimdDouble v3)
{
    __m512d t0, t2;
    __m256d t3, t4;

    assert(std::size_t(m) % 32 == 0);

    t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permute_pd(v0.simdInternal_, 0x55));
    t2 = _mm512_add_pd(v2.simdInternal_, _mm512_permute_pd(v2.simdInternal_, 0x55));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xAA), v1.simdInternal_, _mm512_permute_pd(v1.simdInternal_, 0x55));
    t2 = _mm512_mask_add_pd(t2, avx512Int2Mask(0xAA), v3.simdInternal_, _mm512_permute_pd(v3.simdInternal_, 0x55));
    t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0x4E));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xF0), t2, _mm512_shuffle_f64x2(t2, t2, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0xB1));
    t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0x0C), t0, t0, 0xEE);

    t3 = _mm512_castpd512_pd256(t0);
    t4 = _mm256_load_pd(m);
    t4 = _mm256_add_pd(t4, t3);
    _mm256_store_pd(m, t4);

    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));

    return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
}
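
// Semantics sketch: m[i] += reduce(v_i) for i = 0..3, and the return value is
// the grand total. A plain-C equivalent (with SIMD width 8) would be:
//
//     double sums[4] = {0};
//     for (int j = 0; j < 8; j++) { sums[0] += v0[j]; /* ... v1, v2, v3 */ }
//     for (int i = 0; i < 4; i++) { m[i] += sums[i]; }
//     return sums[0] + sums[1] + sums[2] + sums[3];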

static inline SimdDouble gmx_simdcall
loadDualHsimd(const double * m0,
              const double * m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return {
               _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_load_pd(m0)),
                                  _mm256_load_pd(m1), 1)
    };
}

static inline SimdDouble gmx_simdcall
loadDuplicateHsimd(const double * m)
{
    assert(std::size_t(m) % 32 == 0);

    return {
               _mm512_broadcast_f64x4(_mm256_load_pd(m))
    };
}

static inline SimdDouble gmx_simdcall
loadU1DualHsimd(const double * m)
{
    return {
               _mm512_insertf64x4(_mm512_broadcastsd_pd(_mm_load_sd(m)),
                                  _mm256_broadcastsd_pd(_mm_load_sd(m+1)), 1)
    };
}

static inline void gmx_simdcall
storeDualHsimd(double *     m0,
               double *     m1,
               SimdDouble   a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_pd(m0, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m1, _mm512_extractf64x4_pd(a.simdInternal_, 1));
}

static inline void gmx_simdcall
incrDualHsimd(double *     m0,
              double *     m1,
              SimdDouble   a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256d x;

    // Lower half
    x = _mm256_load_pd(m0);
    x = _mm256_add_pd(x, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m0, x);

    // Upper half
    x = _mm256_load_pd(m1);
    x = _mm256_add_pd(x, _mm512_extractf64x4_pd(a.simdInternal_, 1));
    _mm256_store_pd(m1, x);
}

static inline void gmx_simdcall
decrHsimd(double *    m,
          SimdDouble  a)
{
    __m256d t;

    assert(std::size_t(m) % 32 == 0);

    // Add the two 256-bit halves of a, then subtract the result from m
    a.simdInternal_ = _mm512_add_pd(a.simdInternal_, _mm512_shuffle_f64x2(a.simdInternal_, a.simdInternal_, 0xEE));
    t               = _mm256_load_pd(m);
    t               = _mm256_sub_pd(t, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m, t);
}

template <int align>
static inline void gmx_simdcall
gatherLoadTransposeHsimd(const double *       base0,
                         const double *       base1,
                         const std::int32_t   offset[],
                         SimdDouble *         v0,
                         SimdDouble *         v1)
{
    __m128i idx0, idx1;
    __m256i idx;
    __m512d tmp1, tmp2;

    assert(std::size_t(offset) % 16 == 0);
    assert(std::size_t(base0) % 16 == 0);
    assert(std::size_t(base1) % 16 == 0);

    idx0 = _mm_load_si128(reinterpret_cast<const __m128i*>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    idx0 = _mm_slli_epi32(idx0, align == 2 ? 1 : 2);

    idx1 = _mm_add_epi32(idx0, _mm_set1_epi32(1));

    idx = _mm256_inserti128_si256(_mm256_castsi128_si256(idx0), idx1, 1);

    constexpr size_t scale = sizeof(double);
    tmp1 = _mm512_i32gather_pd(idx, base0, scale); // TODO: Might be faster to use individual loads
    tmp2 = _mm512_i32gather_pd(idx, base1, scale);

    v0->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0x44);
    v1->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0xEE);
}
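
// Sketch (illustrative, align == 2, four offsets o0..o3): v0 receives
// {base0[2*o0..2*o3], base1[2*o0..2*o3]} and v1 the corresponding [2*o+1]
// elements, i.e. pair data from two half-width bases in one full register.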

static inline double gmx_simdcall
reduceIncr4ReturnSumHsimd(double *     m,
                          SimdDouble   v0,
                          SimdDouble   v1)
{
    __m512d t0;
    __m256d t2, t3;

    assert(std::size_t(m) % 32 == 0);

    t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permutex_pd(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xCC), v1.simdInternal_, _mm512_permutex_pd(v1.simdInternal_, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
    t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0xAA), t0, t0, 0xEE);

    t2 = _mm512_castpd512_pd256(t0);
    t3 = _mm256_load_pd(m);
    t3 = _mm256_add_pd(t3, t2);
    _mm256_store_pd(m, t3);

    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));

    return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
}

static inline SimdDouble gmx_simdcall
loadU4NOffset(const double *m, int offset)
{
    return {
               _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_loadu_pd(m)),
                                  _mm256_loadu_pd(m+offset), 1)
    };
}
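
// Usage sketch (data and stride are illustrative names): loads m[0..3] into
// the low half and m[offset..offset+3] into the high half, e.g.
//
//     SimdDouble twoRows = loadU4NOffset(data, stride);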

}      // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H