/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017,2018 by the GROMACS development team.
 * Copyright (c) 2019,2020, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H

#include "config.h"

#include <cassert>
#include <cstddef>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_float.h"

namespace gmx
{

static const int c_simdBestPairAlignmentFloat = 2;

namespace
{
// Multiply function optimized for powers of 2, for which it is done by
// shifting. Currently up to 8 is accelerated. Could be accelerated for any
// number with a constexpr log2 function.
template<int n>
static inline SimdFInt32 fastMultiply(SimdFInt32 x)
{
    if (n == 2)
    {
        return _mm512_slli_epi32(x.simdInternal_, 1);
    }
    else if (n == 4)
    {
        return _mm512_slli_epi32(x.simdInternal_, 2);
    }
    else if (n == 8)
    {
        return _mm512_slli_epi32(x.simdInternal_, 3);
    }
    else
    {
        // Fall back to a regular per-lane multiplication for other factors.
        return x * n;
    }
}
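
// Illustrative usage (not part of the public API): fastMultiply<4>(offset)
// computes offset*4 in each lane as a shift (offset << 2); only n == 2, 4
// and 8 take the shift path, everything else falls through to the generic
// multiply in the else branch.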

template<int align>
static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const float*, SimdFInt32)
{
    // Nothing to do. Termination of recursion.
}
} // namespace

/* This is an internal helper function used by decr3Hsimd(...).
 */
inline void gmx_simdcall decrHsimd(float* m, SimdFloat a)
{
    __m256 t;

    assert(std::size_t(m) % 32 == 0);

    a.simdInternal_ = _mm512_add_ps(a.simdInternal_,
                                    _mm512_shuffle_f32x4(a.simdInternal_, a.simdInternal_, 0xEE));
    t               = _mm256_load_ps(m);
    t               = _mm256_sub_ps(t, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m, t);
}
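
// Net effect: m[i] -= a[i] + a[i+8] for i = 0..7; the shuffle folds the upper
// 256-bit half of a onto the lower one before the subtraction.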

template<int align, typename... Targs>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v, Targs... Fargs)
{
    // For align 1 or 2: No multiplication of offset is needed
    if (align > 2)
    {
        offset = fastMultiply<align>(offset);
    }
    // For align 2: Scale of 2*sizeof(float) is used (maximum supported scale)
    constexpr int align_ = (align > 2) ? 1 : align;
    v->simdInternal_     = _mm512_i32gather_ps(offset.simdInternal_, base, sizeof(float) * align_);
    // Gather remaining elements. Avoid extra multiplication (new align is 1 or 2).
    gatherLoadBySimdIntTranspose<align_>(base + 1, offset, Fargs...);
}
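
// Note on the scale trick: hardware gathers only support byte scales of 1, 2,
// 4 or 8, so align == 2 is expressed through the scale (8 bytes) while larger
// strides are pre-multiplied into the offsets above. The variadic recursion
// peels off one output vector per call; a usage sketch (illustrative only):
// gatherLoadBySimdIntTranspose<3>(xyz, offsets, &x, &y, &z) would load the
// three components of 16 tightly packed triplets.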

template<int align, typename... Targs>
static inline void gmx_simdcall
gatherLoadUBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, offset, v, Fargs...);
}
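
// AVX-512 gathers place no alignment requirement on the base pointer, so the
// unaligned (U) transpose variants can forward to the aligned implementation.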

template<int align, typename... Targs>
static inline void gmx_simdcall
gatherLoadTranspose(const float* base, const std::int32_t offset[], SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
gatherLoadUTranspose(const float* base, const std::int32_t offset[], SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v, Fargs...);
}

template<int align>
static inline void gmx_simdcall
transposeScatterStoreU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    SimdFInt32 simdoffset = simdLoad(offset, SimdFInt32Tag());
    if (align > 2)
    {
        simdoffset = fastMultiply<align>(simdoffset);
    }
    constexpr size_t scale = (align > 2) ? sizeof(float) : sizeof(float) * align;

    _mm512_i32scatter_ps(base, simdoffset.simdInternal_, v0.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
}
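
// Semantics: writes the triplet { v0[i], v1[i], v2[i] } to memory starting at
// base + align*offset[i] for each of the 16 lanes, one component per scatter.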

template<int align>
static inline void gmx_simdcall
transposeScatterIncrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    __m512 t[4], t5, t6, t7, t8;
    int    i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[i]),
                                                                    _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[4 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[4 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[8 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[8 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[12 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7   = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8   = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0] = _mm512_unpacklo_ps(t5, t7); // x0 y0 z0 0 | x4 y4 z4 0
        t[1] = _mm512_unpackhi_ps(t5, t7); // x1 y1 z1 0 | x5 y5 z5 0
        t[2] = _mm512_unpacklo_ps(t6, t8); // x2 y2 z2 0 | x6 y6 z6 0
        t[3] = _mm512_unpackhi_ps(t6, t8); // x3 y3 z3 0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i],
                             _mm_add_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[4 + i],
                             _mm_add_ps(_mm_load_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[8 + i],
                             _mm_add_ps(_mm_load_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i], _mm_add_ps(_mm_load_ps(base + o[12 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i],
                              _mm_add_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[4 + i], _mm_add_ps(_mm_loadu_ps(base + o[4 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[8 + i], _mm_add_ps(_mm_loadu_ps(base + o[8 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i], _mm_add_ps(_mm_loadu_ps(base + o[12 + i]),
                                                           _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}
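
// Two paths above: for align < 4 a plain 4-float store would clobber the next
// triplet, so each 3-element read-modify-write uses a 3-bit store mask
// (avx512Int2Mask(7)); for align >= 4 a padding element exists, so whole
// 4-float chunks are safe, aligned when align is a multiple of 4.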

template<int align>
static inline void gmx_simdcall
transposeScatterDecrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    __m512 t[4], t5, t6, t7, t8;
    int    i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[i]),
                                                                    _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[4 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[4 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[8 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[8 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[12 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7   = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8   = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0] = _mm512_unpacklo_ps(t5, t7); // x0 y0 z0 0 | x4 y4 z4 0
        t[1] = _mm512_unpackhi_ps(t5, t7); // x1 y1 z1 0 | x5 y5 z5 0
        t[2] = _mm512_unpacklo_ps(t6, t8); // x2 y2 z2 0 | x6 y6 z6 0
        t[3] = _mm512_unpackhi_ps(t6, t8); // x3 y3 z3 0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i],
                             _mm_sub_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[4 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[8 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i], _mm_sub_ps(_mm_load_ps(base + o[12 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[4 + i], _mm_sub_ps(_mm_loadu_ps(base + o[4 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[8 + i], _mm_sub_ps(_mm_loadu_ps(base + o[8 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i], _mm_sub_ps(_mm_loadu_ps(base + o[12 + i]),
                                                           _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}
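
// Same data movement as transposeScatterIncrU() above; the transposed
// triplets are subtracted from memory instead of added.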

static inline void gmx_simdcall expandScalarsToTriplets(SimdFloat  scalar,
                                                        SimdFloat* triplets0,
                                                        SimdFloat* triplets1,
                                                        SimdFloat* triplets2)
{
    triplets0->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(5, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1, 0, 0, 0), scalar.simdInternal_);
    triplets1->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(10, 10, 9, 9, 9, 8, 8, 8, 7, 7, 7, 6, 6, 6, 5, 5), scalar.simdInternal_);
    triplets2->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12, 11, 11, 11, 10),
            scalar.simdInternal_);
}
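
// Resulting lane layout, with s0..s15 the elements of scalar:
//   triplets0 = { s0 s0 s0 s1 s1 s1 s2 s2 s2 s3 s3 s3 s4 s4 s4 s5 }
//   triplets1 = { s5 s5 s6 s6 s6 s7 s7 s7 s8 s8 s8 s9 s9 s9 s10 s10 }
//   triplets2 = { s10 s11 s11 s11 s12 s12 s12 s13 s13 s13 s14 s14 s14 s15 s15 s15 }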

static inline float gmx_simdcall reduceIncr4ReturnSum(float* m, SimdFloat v0, SimdFloat v1, SimdFloat v2, SimdFloat v3)
{
    __m512 t0, t1, t2;
    __m128 t3, t4;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_add_ps(v0.simdInternal_, _mm512_permute_ps(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_ps(t0, avx512Int2Mask(0xCCCC), v2.simdInternal_,
                            _mm512_permute_ps(v2.simdInternal_, 0x4E));
    t1 = _mm512_add_ps(v1.simdInternal_, _mm512_permute_ps(v1.simdInternal_, 0x4E));
    t1 = _mm512_mask_add_ps(t1, avx512Int2Mask(0xCCCC), v3.simdInternal_,
                            _mm512_permute_ps(v3.simdInternal_, 0x4E));
    t2 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t2 = _mm512_mask_add_ps(t2, avx512Int2Mask(0xAAAA), t1, _mm512_permute_ps(t1, 0xB1));

    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0x4E));
    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0xB1));

    t3 = _mm512_castps512_ps128(t2);
    t4 = _mm_load_ps(m);
    t4 = _mm_add_ps(t4, t3);
    _mm_store_ps(m, t4);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}
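
// Semantics: m[i] += sum over all 16 elements of v_i (i = 0..3), and the
// function returns the total of those four reduction values.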

static inline SimdFloat gmx_simdcall loadDualHsimd(const float* m0, const float* m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return { _mm512_castpd_ps(_mm512_insertf64x4(
            _mm512_castpd256_pd512(_mm256_load_pd(reinterpret_cast<const double*>(m0))),
            _mm256_load_pd(reinterpret_cast<const double*>(m1)), 1)) };
}

static inline SimdFloat gmx_simdcall loadDuplicateHsimd(const float* m)
{
    assert(std::size_t(m) % 32 == 0);

    return { _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_load_pd(reinterpret_cast<const double*>(m)))) };
}

static inline SimdFloat gmx_simdcall loadU1DualHsimd(const float* m)
{
    return { _mm512_shuffle_f32x4(_mm512_broadcastss_ps(_mm_load_ss(m)),
                                  _mm512_broadcastss_ps(_mm_load_ss(m + 1)), 0x44) };
}
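
// Half-SIMD load layouts ({ lower 8 floats | upper 8 floats }):
//   loadDualHsimd(m0, m1):   { m0[0..7] | m1[0..7] }
//   loadDuplicateHsimd(m):   { m[0..7]  | m[0..7]  }
//   loadU1DualHsimd(m):      { m[0] x8  | m[1] x8  }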

static inline void gmx_simdcall storeDualHsimd(float* m0, float* m1, SimdFloat a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_ps(m0, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_pd(reinterpret_cast<double*>(m1),
                    _mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1));
}

static inline void gmx_simdcall incrDualHsimd(float* m0, float* m1, SimdFloat a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256 x;

    // Lower half of a is added to the 8 floats at m0
    x = _mm256_load_ps(m0);
    x = _mm256_add_ps(x, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m0, x);

    // Upper half of a is added to the 8 floats at m1
    x = _mm256_load_ps(m1);
    x = _mm256_add_ps(x, _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1)));
    _mm256_store_ps(m1, x);
}

static inline void gmx_simdcall decr3Hsimd(float* m, SimdFloat a0, SimdFloat a1, SimdFloat a2)
{
    decrHsimd(m, a0);
    decrHsimd(m + GMX_SIMD_FLOAT_WIDTH / 2, a1);
    decrHsimd(m + GMX_SIMD_FLOAT_WIDTH, a2);
}
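
// Subtracts the folded halves of a0, a1 and a2 from 24 consecutive floats:
// with GMX_SIMD_FLOAT_WIDTH == 16 each call updates 8 floats, at m, m+8 and
// m+16 respectively.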

template<int align>
static inline void gmx_simdcall gatherLoadTransposeHsimd(const float*       base0,
                                                         const float*       base1,
                                                         const std::int32_t offset[],
                                                         SimdFloat*         v0,
                                                         SimdFloat*         v1)
{
    __m256i idx;
    __m512  tmp1, tmp2;

    assert(std::size_t(offset) % 32 == 0);
    assert(std::size_t(base0) % 8 == 0);
    assert(std::size_t(base1) % 8 == 0);

    idx = _mm256_load_si256(reinterpret_cast<const __m256i*>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    if (align == 4)
    {
        idx = _mm256_slli_epi32(idx, 1);
    }

    tmp1 = _mm512_castpd_ps(
            _mm512_i32gather_pd(idx, reinterpret_cast<const double*>(base0), sizeof(double)));
    tmp2 = _mm512_castpd_ps(
            _mm512_i32gather_pd(idx, reinterpret_cast<const double*>(base1), sizeof(double)));

    v0->simdInternal_ = _mm512_mask_moveldup_ps(tmp1, 0xAAAA, tmp2);
    v1->simdInternal_ = _mm512_mask_movehdup_ps(tmp2, 0x5555, tmp1);

    v0->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v0->simdInternal_);
    v1->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v1->simdInternal_);
}
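
// Trick: each aligned float pair is gathered as a single 64-bit element, then
// the pairs from base0 and base1 are de-interleaved so that v0 ends up with
// all first pair elements and v1 with all second ones, base0 data in the
// lower 8 lanes and base1 data in the upper 8.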

static inline float gmx_simdcall reduceIncr4ReturnSumHsimd(float* m, SimdFloat v0, SimdFloat v1)
{
    __m512 t0, t1;
    __m128 t2, t3;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0x88);
    t1 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0xDD);
    t0 = _mm512_add_ps(t0, t1);
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0x4E));
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t0 = _mm512_maskz_compress_ps(avx512Int2Mask(0x1111), t0);

    t3 = _mm512_castps512_ps128(t0);
    t2 = _mm_load_ps(m);
    t2 = _mm_add_ps(t2, t3);
    _mm_store_ps(m, t2);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}
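
// Treats v0 and v1 as four half-width (8-float) vectors: m[0..3] are
// incremented by the four corresponding sums, and the grand total is
// returned.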

static inline SimdFloat gmx_simdcall loadUNDuplicate4(const float* f)
{
    return { _mm512_permute_ps(_mm512_maskz_expandloadu_ps(0x1111, f), 0) };
}
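
// Result layout: { f0 f0 f0 f0 | f1 f1 f1 f1 | f2 f2 f2 f2 | f3 f3 f3 f3 };
// the masked expand-load places f0..f3 in lanes 0, 4, 8 and 12, and the
// permute broadcasts them within each 128-bit chunk.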

static inline SimdFloat gmx_simdcall load4DuplicateN(const float* f)
{
    return { _mm512_broadcast_f32x4(_mm_load_ps(f)) };
}
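
// Result layout: { f0 f1 f2 f3 } duplicated into all four 128-bit chunks.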

static inline SimdFloat gmx_simdcall loadU4NOffset(const float* f, int offset)
{
    const __m256i idx = _mm256_setr_epi32(0, 0, 1, 1, 2, 2, 3, 3);
    const __m256i gdx = _mm256_add_epi32(_mm256_setr_epi32(0, 2, 0, 2, 0, 2, 0, 2),
                                         _mm256_mullo_epi32(idx, _mm256_set1_epi32(offset)));
    return { _mm512_castpd_ps(_mm512_i32gather_pd(gdx, reinterpret_cast<const double*>(f), sizeof(float))) };
}
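
// Gathers four unaligned 4-float blocks at float offsets 0, offset, 2*offset
// and 3*offset from f, as pairs of 64-bit elements with scale sizeof(float):
// result = { f[0..3] | f[offset..offset+3] | ... | f[3*offset..3*offset+3] }.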

} // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H