/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_float.h"

namespace gmx
{

static const int c_simdBestPairAlignmentFloat = 2;
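
// Pair data aligned to 2 floats lets the gathers below fold the stride into
// the hardware scale factor (2*sizeof(float) is the maximum supported scale),
// so no offset multiplication is needed at all; this is presumably why 2 is
// reported as the best pair alignment here.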

// Multiply function optimized for powers of 2, for which it is done by
// shifting. Currently up to 8 is accelerated. Could be accelerated for any
// number with a constexpr log2 function.
template<int n>
static inline SimdFInt32 gmx_simdcall
fastMultiply(SimdFInt32 x)
{
    if (n == 2)
    {
        return _mm512_slli_epi32(x.simdInternal_, 1);
    }
    else if (n == 4)
    {
        return _mm512_slli_epi32(x.simdInternal_, 2);
    }
    else if (n == 8)
    {
        return _mm512_slli_epi32(x.simdInternal_, 3);
    }
    else
    {
        return x * SimdFInt32(n);
    }
}

template <int align>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const float *, SimdFInt32)
{
    // Nothing to do. Termination of recursion.
}
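
// Hedged usage sketch (the names "base" and "offsets" and the choice of
// align are illustrative assumptions, not part of this header): to gather
// the x/y/z components of 16 particles stored as aligned 4-float chunks,
//
//     SimdFloat x, y, z;
//     gatherLoadBySimdIntTranspose<4>(base, offsets, &x, &y, &z);
//
// the variadic overload below peels off one SimdFloat* per step, gathering
// from base, base+1 and base+2, until the empty overload above terminates
// the recursion.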

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const float *base, SimdFInt32 offset, SimdFloat *v, Targs... Fargs)
{
    // For align 1 or 2: No multiplication of offset is needed
    if (align > 2)
    {
        offset = fastMultiply<align>(offset);
    }
    // For align 2: Scale of 2*sizeof(float) is used (maximum supported scale)
    constexpr int align_ = (align > 2) ? 1 : align;
    v->simdInternal_ = _mm512_i32gather_ps(offset.simdInternal_, base, sizeof(float)*align_);
    // Gather remaining elements. Avoid extra multiplication (new align is 1 or 2).
    gatherLoadBySimdIntTranspose<align_>(base+1, offset, Fargs ...);
}

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadUBySimdIntTranspose(const float *base, SimdFInt32 offset, Targs... Fargs)
{
    // AVX-512 gathers have no alignment requirement, so the unaligned
    // variant simply forwards to the aligned one.
    gatherLoadBySimdIntTranspose<align>(base, offset, Fargs ...);
}

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadTranspose(const float *base, const std::int32_t offset[], Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), Fargs ...);
}

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadUTranspose(const float *base, const std::int32_t offset[], Targs... Fargs)
{
    gatherLoadTranspose<align>(base, offset, Fargs ...);
}

template <int align>
static inline void gmx_simdcall
transposeScatterStoreU(float *              base,
                       const std::int32_t   offset[],
                       SimdFloat            v0,
                       SimdFloat            v1,
                       SimdFloat            v2)
{
    SimdFInt32 simdoffset = simdLoad(offset, SimdFInt32Tag());
    if (align > 2)
    {
        simdoffset = fastMultiply<align>(simdoffset);
    }
    constexpr size_t scale = (align > 2) ? sizeof(float) : sizeof(float) * align;

    _mm512_i32scatter_ps(base,       simdoffset.simdInternal_, v0.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
}
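
// Note that the store above can use hardware scatters directly because it
// overwrites the destination. The increment/decrement variants below must
// read-modify-write memory instead, so they transpose in registers and
// update with 128-bit loads/adds/stores per particle.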

template <int align>
static inline void gmx_simdcall
transposeScatterIncrU(float *              base,
                      const std::int32_t   offset[],
                      SimdFloat            v0,
                      SimdFloat            v1,
                      SimdFloat            v2)
{
    __m512 t[4], t5, t6, t7, t8;
    int    i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t    o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_add_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[ 4 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_add_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[ 8 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_add_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_add_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5    = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6    = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7    = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8    = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0]  = _mm512_unpacklo_ps(t5, t7);                             // x0 y0 z0 0 | x4 y4 z4 0
        t[1]  = _mm512_unpackhi_ps(t5, t7);                             // x1 y1 z1 0 | x5 y5 z5 0
        t[2]  = _mm512_unpacklo_ps(t6, t8);                             // x2 y2 z2 0 | x6 y6 z6 0
        t[3]  = _mm512_unpackhi_ps(t6, t8);                             // x3 y3 z3 0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i], _mm_add_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[ 4 + i],
                             _mm_add_ps(_mm_load_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[ 8 + i],
                             _mm_add_ps(_mm_load_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i],
                             _mm_add_ps(_mm_load_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i], _mm_add_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[ 4 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[ 8 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}
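
// Three paths are taken above: for align < 4 only the three real elements
// per particle are updated, via masked 3-element read-modify-write; for
// align >= 4 a zero lane is interleaved so whole 4-float groups can be
// added back, using aligned accesses when align is a multiple of 4 and
// unaligned accesses otherwise.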

template <int align>
static inline void gmx_simdcall
transposeScatterDecrU(float *              base,
                      const std::int32_t   offset[],
                      SimdFloat            v0,
                      SimdFloat            v1,
                      SimdFloat            v2)
{
    __m512 t[4], t5, t6, t7, t8;
    int    i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t    o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_sub_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[ 4 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_sub_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[ 8 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_sub_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_sub_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5    = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6    = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7    = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8    = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0]  = _mm512_unpacklo_ps(t5, t7);                             // x0 y0 z0 0 | x4 y4 z4 0
        t[1]  = _mm512_unpackhi_ps(t5, t7);                             // x1 y1 z1 0 | x5 y5 z5 0
        t[2]  = _mm512_unpacklo_ps(t6, t8);                             // x2 y2 z2 0 | x6 y6 z6 0
        t[3]  = _mm512_unpackhi_ps(t6, t8);                             // x3 y3 z3 0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i], _mm_sub_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[ 4 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[ 8 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i], _mm_sub_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[ 4 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[ 8 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}
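
// transposeScatterDecrU above is the mirror image of transposeScatterIncrU:
// the transposition logic is identical and only _mm_sub_ps replaces
// _mm_add_ps in the read-modify-write updates.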

static inline void gmx_simdcall
expandScalarsToTriplets(SimdFloat    scalar,
                        SimdFloat *  triplets0,
                        SimdFloat *  triplets1,
                        SimdFloat *  triplets2)
{
    triplets0->simdInternal_ = _mm512_permutexvar_ps(_mm512_set_epi32(5, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1, 0, 0, 0),
                                                     scalar.simdInternal_);
    triplets1->simdInternal_ = _mm512_permutexvar_ps(_mm512_set_epi32(10, 10, 9, 9, 9, 8, 8, 8, 7, 7, 7, 6, 6, 6, 5, 5),
                                                     scalar.simdInternal_);
    triplets2->simdInternal_ = _mm512_permutexvar_ps(_mm512_set_epi32(15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12, 11, 11, 11, 10),
                                                     scalar.simdInternal_);
}
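
// Lane picture for the permutes above (s0..s15 are the input lanes, listed
// low lane first):
//   triplets0 = s0  s0  s0  s1  s1  s1  s2  s2  s2  s3  s3  s3  s4  s4  s4  s5
//   triplets1 = s5  s5  s6  s6  s6  s7  s7  s7  s8  s8  s8  s9  s9  s9  s10 s10
//   triplets2 = s10 s11 s11 s11 s12 s12 s12 s13 s13 s13 s14 s14 s14 s15 s15 s15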

static inline float gmx_simdcall
reduceIncr4ReturnSum(float *    m,
                     SimdFloat  v0,
                     SimdFloat  v1,
                     SimdFloat  v2,
                     SimdFloat  v3)
{
    __m512 t0, t1, t2;
    __m128 t3, t4;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_add_ps(v0.simdInternal_, _mm512_permute_ps(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_ps(t0, avx512Int2Mask(0xCCCC), v2.simdInternal_, _mm512_permute_ps(v2.simdInternal_, 0x4E));
    t1 = _mm512_add_ps(v1.simdInternal_, _mm512_permute_ps(v1.simdInternal_, 0xB1));
    t1 = _mm512_mask_add_ps(t1, avx512Int2Mask(0xCCCC), v3.simdInternal_, _mm512_permute_ps(v3.simdInternal_, 0x4E));
    t2 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t2 = _mm512_mask_add_ps(t2, avx512Int2Mask(0xAAAA), t1, _mm512_permute_ps(t1, 0xB1));

    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0x4E));
    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0xB1));

    t3 = _mm512_castps512_ps128(t2);
    t4 = _mm_load_ps(m);
    t4 = _mm_add_ps(t4, t3);
    _mm_store_ps(m, t4);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}
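
// Semantics of the reduction above: m[0..3] are incremented by the
// horizontal sums of v0..v3 respectively, and the total of all four
// horizontal sums is returned as a scalar.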

static inline SimdFloat gmx_simdcall
loadDualHsimd(const float *  m0,
              const float *  m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return {
               _mm512_castpd_ps(_mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_load_pd(reinterpret_cast<const double*>(m0))),
                                                   _mm256_load_pd(reinterpret_cast<const double*>(m1)), 1))
    };
}

static inline SimdFloat gmx_simdcall
loadDuplicateHsimd(const float *  m)
{
    assert(std::size_t(m) % 32 == 0);

    return {
               _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_load_pd(reinterpret_cast<const double*>(m))))
    };
}

static inline SimdFloat gmx_simdcall
loadU1DualHsimd(const float *  m)
{
    return {
               _mm512_shuffle_f32x4(_mm512_broadcastss_ps(_mm_load_ss(m)),
                                    _mm512_broadcastss_ps(_mm_load_ss(m+1)), 0x44)
    };
}
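
// The shuffle above leaves m[0] broadcast across the low 256-bit half and
// m[1] broadcast across the high half, matching the dual-half-SIMD layout
// used by the other Hsimd routines in this file.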

static inline void gmx_simdcall
storeDualHsimd(float *     m0,
               float *     m1,
               SimdFloat   a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_ps(m0, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_pd(reinterpret_cast<double*>(m1), _mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1));
}

static inline void gmx_simdcall
incrDualHsimd(float *     m0,
              float *     m1,
              SimdFloat   a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256 x;

    // Lower half
    x = _mm256_load_ps(m0);
    x = _mm256_add_ps(x, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m0, x);

    // Upper half
    x = _mm256_load_ps(m1);
    x = _mm256_add_ps(x, _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1)));
    _mm256_store_ps(m1, x);
}

static inline void gmx_simdcall
decrHsimd(float *    m,
          SimdFloat  a)
{
    __m256 t;

    assert(std::size_t(m) % 32 == 0);

    a.simdInternal_ = _mm512_add_ps(a.simdInternal_, _mm512_shuffle_f32x4(a.simdInternal_, a.simdInternal_, 0xEE));
    t               = _mm256_load_ps(m);
    t               = _mm256_sub_ps(t, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m, t);
}
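
// Note: the shuffle with immediate 0xEE first folds the high 256-bit half
// of a onto the low half, so the eight floats at m are decremented by the
// lane-wise sum of the two half-SIMD registers packed in a.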

template <int align>
static inline void gmx_simdcall
gatherLoadTransposeHsimd(const float *        base0,
                         const float *        base1,
                         const std::int32_t   offset[],
                         SimdFloat *          v0,
                         SimdFloat *          v1)
{
    __m256i idx;
    __m512  tmp1, tmp2;

    assert(std::size_t(offset) % 32 == 0);
    assert(std::size_t(base0) % 8 == 0);
    assert(std::size_t(base1) % 8 == 0);

    idx = _mm256_load_si256(reinterpret_cast<const __m256i*>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    if (align == 4)
    {
        idx = _mm256_slli_epi32(idx, 1);
    }

    tmp1 = _mm512_castpd_ps(_mm512_i32gather_pd(idx, reinterpret_cast<const double *>(base0), sizeof(double)));
    tmp2 = _mm512_castpd_ps(_mm512_i32gather_pd(idx, reinterpret_cast<const double *>(base1), sizeof(double)));

    v0->simdInternal_ = _mm512_mask_moveldup_ps(tmp1, 0xAAAA, tmp2);
    v1->simdInternal_ = _mm512_mask_movehdup_ps(tmp2, 0x5555, tmp1);

    v0->simdInternal_ = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v0->simdInternal_);
    v1->simdInternal_ = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v1->simdInternal_);
}
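
// The gathers above fetch each x/y pair as one 64-bit double, so a single
// gather brings in both components for eight particles from each base;
// the masked move-dups and the final permutes then sort all x components
// into v0 and all y components into v1, with the base0 results in the low
// 256-bit half and the base1 results in the high half.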

static inline float gmx_simdcall
reduceIncr4ReturnSumHsimd(float *     m,
                          SimdFloat   v0,
                          SimdFloat   v1)
{
    __m512 t0, t1;
    __m128 t2, t3;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0x88);
    t1 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0xDD);
    t0 = _mm512_add_ps(t0, t1);
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0x4E));
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t0 = _mm512_maskz_compress_ps(avx512Int2Mask(0x1111), t0);

    t3 = _mm512_castps512_ps128(t0);
    t2 = _mm_load_ps(m);
    t2 = _mm_add_ps(t2, t3);
    _mm_store_ps(m, t2);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}

static inline SimdFloat gmx_simdcall
loadUNDuplicate4(const float* f)
{
    return {
               _mm512_permute_ps(_mm512_maskz_expandloadu_ps(0x1111, f), 0)
    };
}
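
// Lane picture for loadUNDuplicate4: the masked expand-load places f[0..3]
// in lanes 0, 4, 8 and 12, and the in-lane permute with immediate 0 then
// broadcasts each value across its 128-bit chunk, giving
//   f0 f0 f0 f0 | f1 f1 f1 f1 | f2 f2 f2 f2 | f3 f3 f3 f3.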

static inline SimdFloat gmx_simdcall
load4DuplicateN(const float* f)
{
    return {
               _mm512_broadcast_f32x4(_mm_load_ps(f))
    };
}

static inline SimdFloat gmx_simdcall
loadU4NOffset(const float* f, int offset)
{
    const __m256i idx = _mm256_setr_epi32(0, 0, 1, 1, 2, 2, 3, 3);
    const __m256i gdx = _mm256_add_epi32(_mm256_setr_epi32(0, 2, 0, 2, 0, 2, 0, 2),
                                         _mm256_mullo_epi32(idx, _mm256_set1_epi32(offset)));
    return {
               _mm512_castpd_ps(_mm512_i32gather_pd(gdx, reinterpret_cast<const double*>(f), sizeof(float)))
    };
}
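
// Index picture for the gather above: gdx = {0, 2, o, o+2, 2o, 2o+2, 3o, 3o+2}
// in units of floats (scale sizeof(float), double-sized elements), so the
// result holds f[0..3], f[o..o+3], f[2o..2o+3] and f[3o..3o+3], i.e. four
// consecutive floats from each of four locations spaced offset floats apart.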

}      // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H