/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H

#include "config.h"

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_float.h"

namespace gmx
{

static const int c_simdBestPairAlignmentFloat = 2;
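
// A pair alignment of 2 is the best case here: the AVX-512 gather and
// scatter instructions support a scale of at most 8 bytes, i.e. two floats,
// so pairs can be accessed directly without pre-multiplying the offsets.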

namespace
{
// Multiply function optimized for powers of 2, for which it is done by
// shifting. Currently up to 8 is accelerated. Could be accelerated for any
// number with a constexpr log2 function.
template<int n>
SimdFInt32 fastMultiply(SimdFInt32 x)
{
    if (n == 2)
    {
        return _mm512_slli_epi32(x.simdInternal_, 1);
    }
    else if (n == 4)
    {
        return _mm512_slli_epi32(x.simdInternal_, 2);
    }
    else if (n == 8)
    {
        return _mm512_slli_epi32(x.simdInternal_, 3);
    }
    else
    {
        return x * n;
    }
}
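
// Illustrative expansion (sketch; x is any SimdFInt32): fastMultiply<4>(x)
// compiles to a single shift, x << 2, while a non-power-of-2 factor such as
// fastMultiply<3>(x) falls back to the generic SIMD integer multiply x * 3.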

template<int align>
static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const float*, SimdFInt32)
{
    // Nothing to do. Termination of recursion.
}
} // namespace

template<int align, typename... Targs>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v, Targs... Fargs)
{
    // For align 1 or 2: No multiplication of offset is needed
    if (align > 2)
    {
        offset = fastMultiply<align>(offset);
    }
    // For align 2: Scale of 2*sizeof(float) is used (maximum supported scale)
    constexpr int align_ = (align > 2) ? 1 : align;
    v->simdInternal_ = _mm512_i32gather_ps(offset.simdInternal_, base, sizeof(float) * align_);
    // Gather remaining elements. Avoid extra multiplication (new align is 1 or 2).
    gatherLoadBySimdIntTranspose<align_>(base + 1, offset, Fargs...);
}
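
// Illustrative call (sketch; variable names are examples only): gathering
// x/y/z/q from particle data stored as xyzq quadruplets, i.e. align = 4:
//
//   gatherLoadBySimdIntTranspose<4>(xyzq, idx, &x, &y, &z, &q);
//
// The recursion gathers one SIMD variable per step from base, base + 1, ...,
// and since the offsets are already scaled after the first step, the
// remaining steps run with align_ = 1 and skip the multiplication.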

template<int align, typename... Targs>
static inline void gmx_simdcall
gatherLoadUBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, offset, v, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
gatherLoadTranspose(const float* base, const std::int32_t offset[], SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
gatherLoadUTranspose(const float* base, const std::int32_t offset[], SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v, Fargs...);
}
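
// The AVX-512 gather instructions carry no alignment requirements, which is
// why the unaligned (U) transpose loads above can simply forward to the
// aligned implementation.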

template<int align>
static inline void gmx_simdcall
transposeScatterStoreU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    SimdFInt32 simdoffset = simdLoad(offset, SimdFInt32Tag());
    if (align > 2)
    {
        simdoffset = fastMultiply<align>(simdoffset);
    }
    constexpr size_t scale = (align > 2) ? sizeof(float) : sizeof(float) * align;

    _mm512_i32scatter_ps(base, simdoffset.simdInternal_, v0.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
}
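
// Illustrative use (sketch; names are examples only): scattering x/y/z
// forces to particles stored as xyz triplets, i.e. align = 3:
//
//   transposeScatterStoreU<3>(force, idx, fx, fy, fz);
//
// Since 3 * sizeof(float) is not a valid hardware scale, the offsets are
// pre-multiplied and the scatters then run with a scale of sizeof(float).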

template<int align>
static inline void gmx_simdcall
transposeScatterIncrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    __m512 t[4], t5, t6, t7, t8;
    int    i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[i]),
                                                                    _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[4 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[4 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[8 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[8 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[12 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7   = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8   = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0] = _mm512_unpacklo_ps(t5, t7); // x0 y0 z0 0 | x4 y4 z4 0
        t[1] = _mm512_unpackhi_ps(t5, t7); // x1 y1 z1 0 | x5 y5 z5 0
        t[2] = _mm512_unpacklo_ps(t6, t8); // x2 y2 z2 0 | x6 y6 z6 0
        t[3] = _mm512_unpackhi_ps(t6, t8); // x3 y3 z3 0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i],
                             _mm_add_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[4 + i],
                             _mm_add_ps(_mm_load_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[8 + i],
                             _mm_add_ps(_mm_load_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i],
                             _mm_add_ps(_mm_load_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i],
                              _mm_add_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[4 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[8 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}
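
// Three code paths above: for align < 4 only three floats per particle may be
// touched, so masked width-3 read-modify-write stores are used; for
// align >= 4 the fourth element is padding and a full 4-wide update is safe,
// using aligned stores when align is a multiple of 4 and unaligned otherwise.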

template<int align>
static inline void gmx_simdcall
transposeScatterDecrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    __m512 t[4], t5, t6, t7, t8;
    int    i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[i]),
                                                                    _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[4 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[4 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[8 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[8 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[12 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7   = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8   = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0] = _mm512_unpacklo_ps(t5, t7); // x0 y0 z0 0 | x4 y4 z4 0
        t[1] = _mm512_unpackhi_ps(t5, t7); // x1 y1 z1 0 | x5 y5 z5 0
        t[2] = _mm512_unpacklo_ps(t6, t8); // x2 y2 z2 0 | x6 y6 z6 0
        t[3] = _mm512_unpackhi_ps(t6, t8); // x3 y3 z3 0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i],
                             _mm_sub_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[4 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[8 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[4 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[8 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}
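
// transposeScatterDecrU mirrors transposeScatterIncrU above exactly, with
// _mm_sub_ps in place of _mm_add_ps in every read-modify-write update.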

static inline void gmx_simdcall expandScalarsToTriplets(SimdFloat  scalar,
                                                        SimdFloat* triplets0,
                                                        SimdFloat* triplets1,
                                                        SimdFloat* triplets2)
{
    triplets0->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(5, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1, 0, 0, 0), scalar.simdInternal_);
    triplets1->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(10, 10, 9, 9, 9, 8, 8, 8, 7, 7, 7, 6, 6, 6, 5, 5), scalar.simdInternal_);
    triplets2->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12, 11, 11, 11, 10),
            scalar.simdInternal_);
}
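
// Example of the resulting lane layout: for scalar = {s0, s1, ..., s15},
//   triplets0 = s0 s0 s0 s1 s1 s1 s2 s2 s2 s3 s3 s3 s4 s4 s4 s5
//   triplets1 = s5 s5 s6 s6 s6 s7 s7 s7 s8 s8 s8 s9 s9 s9 s10 s10
//   triplets2 = s10 s11 s11 s11 s12 s12 s12 s13 s13 s13 s14 s14 s14 s15 s15 s15
// i.e. the three outputs together hold each scalar repeated three times.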

static inline float gmx_simdcall reduceIncr4ReturnSum(float* m, SimdFloat v0, SimdFloat v1, SimdFloat v2, SimdFloat v3)
{
    __m512 t0, t1, t2;
    __m128 t3, t4;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_add_ps(v0.simdInternal_, _mm512_permute_ps(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_ps(t0, avx512Int2Mask(0xCCCC), v2.simdInternal_,
                            _mm512_permute_ps(v2.simdInternal_, 0x4E));
    t1 = _mm512_add_ps(v1.simdInternal_, _mm512_permute_ps(v1.simdInternal_, 0x4E));
    t1 = _mm512_mask_add_ps(t1, avx512Int2Mask(0xCCCC), v3.simdInternal_,
                            _mm512_permute_ps(v3.simdInternal_, 0x4E));
    t2 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t2 = _mm512_mask_add_ps(t2, avx512Int2Mask(0xAAAA), t1, _mm512_permute_ps(t1, 0xB1));

    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0x4E));
    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0xB1));

    t3 = _mm512_castps512_ps128(t2);
    t4 = _mm_load_ps(m);
    t4 = _mm_add_ps(t4, t3);
    _mm_store_ps(m, t4);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}
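
// After the masked adds and permutes, lane i of the low 128-bit quarter
// holds the full horizontal sum of vi, so m[0..3] are incremented by
// sum(v0)..sum(v3) and the grand total of the four sums is returned.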

static inline SimdFloat gmx_simdcall loadDualHsimd(const float* m0, const float* m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return { _mm512_castpd_ps(_mm512_insertf64x4(
            _mm512_castpd256_pd512(_mm256_load_pd(reinterpret_cast<const double*>(m0))),
            _mm256_load_pd(reinterpret_cast<const double*>(m1)), 1)) };
}

static inline SimdFloat gmx_simdcall loadDuplicateHsimd(const float* m)
{
    assert(std::size_t(m) % 32 == 0);
    return { _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_load_pd(reinterpret_cast<const double*>(m)))) };
}
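
// The double-precision view in loadDualHsimd/loadDuplicateHsimd is only a
// device to move the floats in 256-bit chunks; no arithmetic is performed
// in double precision.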

static inline SimdFloat gmx_simdcall loadU1DualHsimd(const float* m)
{
    return { _mm512_shuffle_f32x4(_mm512_broadcastss_ps(_mm_load_ss(m)),
                                  _mm512_broadcastss_ps(_mm_load_ss(m + 1)), 0x44) };
}
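
// loadU1DualHsimd broadcasts m[0] into the lower eight lanes and m[1] into
// the upper eight lanes (imm 0x44 takes two 128-bit lanes from each
// broadcast), matching the half-SIMD layout used by the functions above.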

static inline void gmx_simdcall storeDualHsimd(float* m0, float* m1, SimdFloat a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_ps(m0, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_pd(reinterpret_cast<double*>(m1),
                    _mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1));
}

static inline void gmx_simdcall incrDualHsimd(float* m0, float* m1, SimdFloat a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256 x;

    // Lower half
    x = _mm256_load_ps(m0);
    x = _mm256_add_ps(x, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m0, x);

    // Upper half
    x = _mm256_load_ps(m1);
    x = _mm256_add_ps(x, _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1)));
    _mm256_store_ps(m1, x);
}
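
// incrDualHsimd adds the lower 256-bit half of a to m0[0..7] and the upper
// half to m1[0..7]; the pd cast is again only a device to extract the upper half.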

static inline void gmx_simdcall decrHsimd(float* m, SimdFloat a)
{
    __m256 t;

    assert(std::size_t(m) % 32 == 0);

    a.simdInternal_ = _mm512_add_ps(a.simdInternal_,
                                    _mm512_shuffle_f32x4(a.simdInternal_, a.simdInternal_, 0xEE));
    t = _mm256_load_ps(m);
    t = _mm256_sub_ps(t, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m, t);
}
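
// decrHsimd first folds the upper 256-bit half of a onto the lower half
// (imm 0xEE selects the two upper 128-bit lanes), then performs
// m[i] -= a[i] + a[i + 8] for i = 0..7.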

template<int align>
static inline void gmx_simdcall gatherLoadTransposeHsimd(const float*       base0,
                                                         const float*       base1,
                                                         const std::int32_t offset[],
                                                         SimdFloat*         v0,
                                                         SimdFloat*         v1)
{
    __m256i idx;
    __m512  tmp1, tmp2;

    assert(std::size_t(offset) % 32 == 0);
    assert(std::size_t(base0) % 8 == 0);
    assert(std::size_t(base1) % 8 == 0);

    idx = _mm256_load_si256(reinterpret_cast<const __m256i*>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    if (align == 4)
    {
        idx = _mm256_slli_epi32(idx, 1);
    }

    tmp1 = _mm512_castpd_ps(
            _mm512_i32gather_pd(idx, reinterpret_cast<const double*>(base0), sizeof(double)));
    tmp2 = _mm512_castpd_ps(
            _mm512_i32gather_pd(idx, reinterpret_cast<const double*>(base1), sizeof(double)));

    v0->simdInternal_ = _mm512_mask_moveldup_ps(tmp1, 0xAAAA, tmp2);
    v1->simdInternal_ = _mm512_mask_movehdup_ps(tmp2, 0x5555, tmp1);

    v0->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v0->simdInternal_);
    v1->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v1->simdInternal_);
}
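
// Each pair of consecutive floats is gathered as one 64-bit double, halving
// the number of gathers; the masked move-duplicates interleave the pairs from
// base0 and base1, and the final permute leaves v0 with the first and v1 with
// the second element of every pair, base0 data in the lower half and base1
// data in the upper half.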

static inline float gmx_simdcall reduceIncr4ReturnSumHsimd(float* m, SimdFloat v0, SimdFloat v1)
{
    __m512 t0, t1;
    __m128 t2, t3;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0x88);
    t1 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0xDD);
    t0 = _mm512_add_ps(t0, t1);
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0x4E));
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t0 = _mm512_maskz_compress_ps(avx512Int2Mask(0x1111), t0);

    t3 = _mm512_castps512_ps128(t0);
    t2 = _mm_load_ps(m);
    t2 = _mm_add_ps(t2, t3);
    _mm_store_ps(m, t2);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}
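
// Half-SIMD variant of reduceIncr4ReturnSum: the four quantities reduced are
// the two 8-lane halves of v0 and v1, whose sums are compressed into the low
// 128 bits, added to m[0..3], and returned as a grand total.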

static inline SimdFloat gmx_simdcall loadUNDuplicate4(const float* f)
{
    return { _mm512_permute_ps(_mm512_maskz_expandloadu_ps(0x1111, f), 0) };
}
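
// Yields { f[0] x4, f[1] x4, f[2] x4, f[3] x4 } from an unaligned pointer:
// the masked expand-load places f[0..3] in lanes 0, 4, 8 and 12, and the
// in-lane permute with imm 0 broadcasts them across each 128-bit lane.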

static inline SimdFloat gmx_simdcall load4DuplicateN(const float* f)
{
    return { _mm512_broadcast_f32x4(_mm_load_ps(f)) };
}

static inline SimdFloat gmx_simdcall loadU4NOffset(const float* f, int offset)
{
    const __m256i idx = _mm256_setr_epi32(0, 0, 1, 1, 2, 2, 3, 3);
    const __m256i gdx = _mm256_add_epi32(_mm256_setr_epi32(0, 2, 0, 2, 0, 2, 0, 2),
                                         _mm256_mullo_epi32(idx, _mm256_set1_epi32(offset)));
    return { _mm512_castpd_ps(_mm512_i32gather_pd(gdx, reinterpret_cast<const double*>(f), sizeof(float))) };
}
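
// loadU4NOffset returns { f[0..3], f[offset ...], f[2 * offset ...],
// f[3 * offset ...] }: each quadruplet is fetched as two 64-bit gather
// elements with a scale of sizeof(float), so no alignment is required.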

} // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H