2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2014-2018, The GROMACS development team.
5 * Copyright (c) 2019,2020, by the GROMACS development team, led by
6 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7 * and including many others, as listed in the AUTHORS file in the
8 * top-level source directory and at http://www.gromacs.org.
10 * GROMACS is free software; you can redistribute it and/or
11 * modify it under the terms of the GNU Lesser General Public License
12 * as published by the Free Software Foundation; either version 2.1
13 * of the License, or (at your option) any later version.
15 * GROMACS is distributed in the hope that it will be useful,
16 * but WITHOUT ANY WARRANTY; without even the implied warranty of
17 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18 * Lesser General Public License for more details.
20 * You should have received a copy of the GNU Lesser General Public
21 * License along with GROMACS; if not, see
22 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
25 * If you want to redistribute modifications to GROMACS, please
26 * consider that scientific software is very special. Version
27 * control is crucial - bugs must be traceable. We will be happy to
28 * consider code for inclusion in the official distribution, but
29 * derived work must not be called official GROMACS. Details are found
30 * in the README & COPYING files - if they are missing, get the
31 * official version at http://www.gromacs.org.
33 * To help us fund GROMACS development, we humbly ask that you cite
34 * the research papers on the package. Check out http://www.gromacs.org.
37 #ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H
38 #define GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H
45 #include <immintrin.h>
47 #include "gromacs/utility/basedefinitions.h"
49 #include "impl_x86_avx_512_general.h"
50 #include "impl_x86_avx_512_simd_double.h"
// NOTE(review): presumably the preferred alignment (in elements) for paired
// loads in the double-precision AVX-512 backend; its exact meaning is defined
// by the generic SIMD interface, which is not part of this listing — confirm.
55 static const int c_simdBestPairAlignmentDouble = 2;
59 // Multiply function optimized for powers of 2, for which it is done by
60 // shifting. Currently up to 8 is accelerated. Could be accelerated for any
61 // number with a constexpr log2 function.
// Each visible branch shifts every 32-bit lane of x left by k bits, which
// multiplies the lane by 2^k.
// NOTE(review): this listing has gaps in its embedded line numbering
// (63 -> 67 -> 71 -> 75 -> 84). The template header, the tests that select
// each branch, and any fallback for other multipliers are not shown here;
// confirm them against the full source.
63 static inline SimdDInt32 fastMultiply(SimdDInt32 x)
// x * 2
67 return _mm256_slli_epi32(x.simdInternal_, 1);
// x * 4
71 return _mm256_slli_epi32(x.simdInternal_, 2);
// x * 8
75 return _mm256_slli_epi32(x.simdInternal_, 3);
// Base case of the variadic gatherLoadBySimdIntTranspose recursion. The
// recursive overload calls gatherLoadBySimdIntTranspose<1>(base + 1, offset,
// Fargs...). Once no SimdDouble* outputs remain, that call resolves to this
// overload, which takes only the pointer and the offsets.
84 static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const double*, SimdDInt32)
86 // Nothing to do. Termination of recursion.
89 /* This is an internal helper function used by decr3Hsimd(...).
// Subtracts a[i] + a[i + 4] from m[i] for i = 0..3, i.e. it subtracts the sum
// of the two 256-bit halves of a from four doubles at m.
// m must be 32-byte aligned (asserted below).
91 inline void gmx_simdcall decrHsimd(double* m, SimdDouble a)
95 assert(std::size_t(m) % 32 == 0);
// Shuffle 0xEE selects 128-bit lanes (2, 3, 2, 3), duplicating the upper half.
// Adding it leaves a[i] + a[i + 4] in the lower four elements.
97 a.simdInternal_ = _mm512_add_pd(a.simdInternal_,
98 _mm512_shuffle_f64x2(a.simdInternal_, a.simdInternal_, 0xEE));
// Aligned read-modify-write of m[0..3].
// NOTE(review): the declaration of t (presumably __m256d) falls in a gap of
// this listing.
99 t = _mm256_load_pd(m);
100 t = _mm256_sub_pd(t, _mm512_castpd512_pd256(a.simdInternal_));
101 _mm256_store_pd(m, t);
// Gathers one SimdDouble per output pointer from align-strided records. For
// each lane i, *v receives base[offset[i] * align], the next output receives
// base[offset[i] * align + 1], and so on. Consecutive record components are
// thereby transposed into separate registers.
106 template<int align, typename... Targs>
107 static inline void gmx_simdcall
108 gatherLoadBySimdIntTranspose(const double* base, SimdDInt32 offset, SimdDouble* v, Targs... Fargs)
// Convert record indices to element indices once. The recursive call below
// uses align = 1, so the already-scaled offsets are reused unchanged.
// NOTE(review): the lines around this statement (gap 108 -> 112) are not
// shown; it may be conditional on align in the full source.
112 offset = fastMultiply<align>(offset);
// The gather scale is in bytes, so lane i reads base[offset[i]].
114 constexpr size_t scale = sizeof(double);
115 v->simdInternal_ = _mm512_i32gather_pd(offset.simdInternal_, base, scale);
// Next component: advance base by one element and recurse on the remaining outputs.
116 gatherLoadBySimdIntTranspose<1>(base + 1, offset, Fargs...);
// "U" variant of gatherLoadBySimdIntTranspose. This implementation has no
// separate path and forwards unchanged to gatherLoadBySimdIntTranspose<align>.
119 template<int align, typename... Targs>
120 static inline void gmx_simdcall
121 gatherLoadUBySimdIntTranspose(const double* base, SimdDInt32 offset, SimdDouble* v, Targs... Fargs)
123 gatherLoadBySimdIntTranspose<align>(base, offset, v, Fargs...);
// Array-offset front end of the gather-transpose. It loads offset[] into a
// SimdDInt32 with simdLoad(..., SimdDInt32Tag()) and delegates to
// gatherLoadBySimdIntTranspose<align>.
// NOTE(review): simdLoad is presumably an aligned load, in which case offset[]
// must be SIMD-aligned — confirm against its definition.
126 template<int align, typename... Targs>
127 static inline void gmx_simdcall
128 gatherLoadTranspose(const double* base, const std::int32_t offset[], SimdDouble* v, Targs... Fargs)
130 gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), v, Fargs...);
// "U" variant of gatherLoadTranspose. Its body is identical to
// gatherLoadTranspose: offsets are loaded with simdLoad, and the gather path
// does not distinguish the two variants here.
// NOTE(review): because simdLoad is used, offset[] presumably has the same
// alignment requirement as in gatherLoadTranspose — confirm.
133 template<int align, typename... Targs>
134 static inline void gmx_simdcall
135 gatherLoadUTranspose(const double* base, const std::int32_t offset[], SimdDouble* v, Targs... Fargs)
137 gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), v, Fargs...);
// Scatters three SimdDouble values into align-strided records. For each lane i:
//   base[offset[i] * align + 0] = v0[i]
//   base[offset[i] * align + 1] = v1[i]
//   base[offset[i] * align + 2] = v2[i]
// NOTE(review): the template header and the parameter lines for v0, v1, v2
// fall in numbering gaps (140, 143-146) and are not shown in this listing.
141 static inline void gmx_simdcall transposeScatterStoreU(double* base,
142 const std::int32_t offset[],
147 SimdDInt32 simdoffset = simdLoad(offset, SimdDInt32Tag());
// Convert record indices to element indices.
// NOTE(review): this may be guarded by a condition on align that falls in the
// 147 -> 151 gap.
151 simdoffset = fastMultiply<align>(simdoffset);
// The scatter scale is in bytes: one double per index step.
154 constexpr size_t scale = sizeof(double);
155 _mm512_i32scatter_pd(base, simdoffset.simdInternal_, v0.simdInternal_, scale);
156 _mm512_i32scatter_pd(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
157 _mm512_i32scatter_pd(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
// Adds three SimdDouble values into align-strided records. For each lane i and
// c = 0, 1, 2:  base[offset[i] * align + c] += vc[i].
// The offsets are read with an aligned 256-bit load (_mm256_load_si256), so
// offset[] must be 32-byte aligned.
// NOTE(review): the template header and the conditions that choose between
// the three store loops below fall in numbering gaps (160, 177-178, 188-193,
// 200-203) — confirm them against the full source.
161 static inline void gmx_simdcall
162 transposeScatterIncrU(double* base, const std::int32_t offset[], SimdDouble v0, SimdDouble v1, SimdDouble v2)
164 __m512d t[4], t5, t6, t7, t8;
// o[j] = offset[j] * align, widened to 64 bits for pointer arithmetic.
165 alignas(GMX_SIMD_ALIGNMENT) std::int64_t o[8];
166 // TODO: should use fastMultiply
167 _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(
168 _mm256_load_si256((const __m256i*)(offset)), _mm256_set1_epi32(align))));
// Transpose step. Afterwards, for i = 0..3, the lower 256 bits of t[i] hold
// (v0[i], v1[i], v2[i], 0) and the upper 256 bits hold
// (v0[i + 4], v1[i + 4], v2[i + 4], 0). The 0x4E permute swaps 128-bit halves
// within each 256-bit lane; masks 0xCC/0x33 pick which elements take it.
169 t5 = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
170 t6 = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
171 t7 = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
172 t8 = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
173 t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
174 t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
175 t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
176 t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
// Masked path. Each record is loaded four-wide (unaligned), so
// base[o[j] + 3] must be readable. Only the first three elements (mask 7) are
// written back, so the element after each record is never modified. Records
// are updated one at a time, each load following the previous store, so
// repeated offsets accumulate correctly.
179 for (int i = 0; i < 4; i++)
181 _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7),
182 _mm512_castpd256_pd512(_mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]),
183 _mm512_castpd512_pd256(t[i]))));
184 _mm512_mask_storeu_pd(
185 base + o[4 + i], avx512Int2Mask(7),
186 _mm512_castpd256_pd512(_mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]),
187 _mm512_extractf64x4_pd(t[i], 1))));
// Aligned path. Full 32-byte aligned read-modify-write of four elements; the
// fourth element gets + 0.0.
// NOTE(review): presumably chosen when base + o[j] is guaranteed 32-byte
// aligned — confirm.
194 for (int i = 0; i < 4; i++)
196 _mm256_store_pd(base + o[0 + i], _mm256_add_pd(_mm256_load_pd(base + o[0 + i]),
197 _mm512_castpd512_pd256(t[i])));
198 _mm256_store_pd(base + o[4 + i], _mm256_add_pd(_mm256_load_pd(base + o[4 + i]),
199 _mm512_extractf64x4_pd(t[i], 1)));
// Unaligned path. Same as the aligned path, using unaligned 256-bit loads and
// stores.
204 for (int i = 0; i < 4; i++)
206 _mm256_storeu_pd(base + o[0 + i], _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]),
207 _mm512_castpd512_pd256(t[i])));
208 _mm256_storeu_pd(base + o[4 + i], _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]),
209 _mm512_extractf64x4_pd(t[i], 1)));
// Subtracts three SimdDouble values from align-strided records. For each lane
// i and c = 0, 1, 2:  base[offset[i] * align + c] -= vc[i].
// The offsets are read with an aligned 256-bit load (_mm256_load_si256), so
// offset[] must be 32-byte aligned.
// NOTE(review): the template header and the conditions that choose between
// the three store loops below fall in numbering gaps (215, 232-233, 243-248,
// 255-258) — confirm them against the full source.
216 static inline void gmx_simdcall
217 transposeScatterDecrU(double* base, const std::int32_t offset[], SimdDouble v0, SimdDouble v1, SimdDouble v2)
219 __m512d t[4], t5, t6, t7, t8;
// o[j] = offset[j] * align, widened to 64 bits for pointer arithmetic.
220 alignas(GMX_SIMD_ALIGNMENT) std::int64_t o[8];
221 // TODO: should use fastMultiply
222 _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(
223 _mm256_load_si256((const __m256i*)(offset)), _mm256_set1_epi32(align))));
// Transpose step. Afterwards, for i = 0..3, the lower 256 bits of t[i] hold
// (v0[i], v1[i], v2[i], 0) and the upper 256 bits hold
// (v0[i + 4], v1[i + 4], v2[i + 4], 0). The four t[] assignments are
// independent, so computing t[2] before t[1] has no effect.
224 t5 = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
225 t6 = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
226 t7 = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
227 t8 = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
228 t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
229 t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
230 t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
231 t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
// Masked path. Each record is loaded four-wide (unaligned), so
// base[o[j] + 3] must be readable. Only the first three elements (mask 7) are
// written back. Records are updated one at a time, each load following the
// previous store, so repeated offsets accumulate correctly.
234 for (int i = 0; i < 4; i++)
236 _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7),
237 _mm512_castpd256_pd512(_mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]),
238 _mm512_castpd512_pd256(t[i]))));
239 _mm512_mask_storeu_pd(
240 base + o[4 + i], avx512Int2Mask(7),
241 _mm512_castpd256_pd512(_mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]),
242 _mm512_extractf64x4_pd(t[i], 1))));
// Aligned path. Full 32-byte aligned read-modify-write of four elements; the
// fourth element gets - 0.0.
// NOTE(review): presumably chosen when base + o[j] is guaranteed 32-byte
// aligned — confirm.
249 for (int i = 0; i < 4; i++)
251 _mm256_store_pd(base + o[0 + i], _mm256_sub_pd(_mm256_load_pd(base + o[0 + i]),
252 _mm512_castpd512_pd256(t[i])));
253 _mm256_store_pd(base + o[4 + i], _mm256_sub_pd(_mm256_load_pd(base + o[4 + i]),
254 _mm512_extractf64x4_pd(t[i], 1)));
// Unaligned path. Same as the aligned path, using unaligned 256-bit loads and
// stores.
259 for (int i = 0; i < 4; i++)
261 _mm256_storeu_pd(base + o[0 + i], _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]),
262 _mm512_castpd512_pd256(t[i])));
263 _mm256_storeu_pd(base + o[4 + i], _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]),
264 _mm512_extractf64x4_pd(t[i], 1)));
// Repeats each of the eight doubles s0..s7 in scalar three times, in order,
// across 24 output lanes:
//   *triplets0 = s0 s0 s0 s1 s1 s1 s2 s2
//   *triplets1 = s2 s3 s3 s3 s4 s4 s4 s5
//   *triplets2 = s5 s5 s6 s6 s6 s7 s7 s7
// The doubles are moved with 32-bit permutes: double k is the 32-bit index
// pair (2k, 2k + 1). _mm512_set_epi32 lists those indices from the highest
// lane down to lane 0.
270 static inline void gmx_simdcall expandScalarsToTriplets(SimdDouble scalar,
271 SimdDouble* triplets0,
272 SimdDouble* triplets1,
273 SimdDouble* triplets2)
275 triplets0->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
276 _mm512_set_epi32(5, 4, 5, 4, 3, 2, 3, 2, 3, 2, 1, 0, 1, 0, 1, 0),
277 _mm512_castpd_si512(scalar.simdInternal_)));
278 triplets1->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
279 _mm512_set_epi32(11, 10, 9, 8, 9, 8, 9, 8, 7, 6, 7, 6, 7, 6, 5, 4),
280 _mm512_castpd_si512(scalar.simdInternal_)));
281 triplets2->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
282 _mm512_set_epi32(15, 14, 15, 14, 15, 14, 13, 12, 13, 12, 13, 12, 11, 10, 11, 10),
283 _mm512_castpd_si512(scalar.simdInternal_)));
// Adds the horizontal sum of each input to m, m[k] += sum(vk) for k = 0..3,
// and returns sum(v0) + sum(v1) + sum(v2) + sum(v3).
// m must be 32-byte aligned (asserted below).
// NOTE(review): the declarations of t0, t2, t3 and t4 fall in a gap of this
// listing (lines 289-292).
287 static inline double gmx_simdcall
288 reduceIncr4ReturnSum(double* m, SimdDouble v0, SimdDouble v1, SimdDouble v2, SimdDouble v3)
293 assert(std::size_t(m) % 32 == 0);
// Neighbour-pair sums (permute 0x55 swaps adjacent elements). Even elements of
// t0 hold v0[2k] + v0[2k + 1]; odd elements (mask 0xAA) hold the same for v1.
// Likewise t2 holds v2 (even) and v3 (odd).
295 t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permute_pd(v0.simdInternal_, 0x55));
296 t2 = _mm512_add_pd(v2.simdInternal_, _mm512_permute_pd(v2.simdInternal_, 0x55));
297 t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xAA), v1.simdInternal_,
298 _mm512_permute_pd(v1.simdInternal_, 0x55));
299 t2 = _mm512_mask_add_pd(t2, avx512Int2Mask(0xAA), v3.simdInternal_,
300 _mm512_permute_pd(v3.simdInternal_, 0x55));
// Add the swapped 256-bit halves. The lower half keeps reducing v0/v1; the
// upper half (mask 0xF0) takes the same reduction of t2, i.e. v2/v3.
301 t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0x4E));
302 t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xF0), t2, _mm512_shuffle_f64x2(t2, t2, 0x4E));
// Add the swapped adjacent 128-bit lanes: lane 0 = (sum v0, sum v1) and
// lane 2 = (sum v2, sum v3).
303 t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0xB1));
// Copy lane 3 into lane 1 (mask 0x0C). The lower 256 bits are now
// (sum v0, sum v1, sum v2, sum v3).
304 t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0x0C), t0, t0, 0xEE);
// m[0..3] += the four sums (aligned 256-bit read-modify-write).
306 t3 = _mm512_castpd512_pd256(t0);
307 t4 = _mm256_load_pd(m);
308 t4 = _mm256_add_pd(t4, t3);
309 _mm256_store_pd(m, t4);
// Fold the four sums into element 0 and return it.
311 t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
312 t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
314 return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
// Returns (m0[0..3], m1[0..3]): the lower 256 bits come from m0 and the upper
// 256 bits from m1. Both pointers must be 32-byte aligned (asserted).
317 static inline SimdDouble gmx_simdcall loadDualHsimd(const double* m0, const double* m1)
319 assert(std::size_t(m0) % 32 == 0);
320 assert(std::size_t(m1) % 32 == 0);
322 return { _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_load_pd(m0)), _mm256_load_pd(m1), 1) };
// Returns (m[0..3], m[0..3]): four doubles loaded once and broadcast to both
// 256-bit halves. m must be 32-byte aligned (asserted).
325 static inline SimdDouble gmx_simdcall loadDuplicateHsimd(const double* m)
327 assert(std::size_t(m) % 32 == 0);
329 return { _mm512_broadcast_f64x4(_mm256_load_pd(m)) };
// Returns a vector whose lower four lanes all equal m[0] and whose upper four
// lanes all equal m[1]. The loads are scalar, and no alignment is asserted.
332 static inline SimdDouble gmx_simdcall loadU1DualHsimd(const double* m)
334 return { _mm512_insertf64x4(_mm512_broadcastsd_pd(_mm_load_sd(m)),
335 _mm256_broadcastsd_pd(_mm_load_sd(m + 1)), 1) };
// Stores the lower 256 bits of a to m0[0..3] and the upper 256 bits to
// m1[0..3]. Both pointers must be 32-byte aligned (asserted).
339 static inline void gmx_simdcall storeDualHsimd(double* m0, double* m1, SimdDouble a)
341 assert(std::size_t(m0) % 32 == 0);
342 assert(std::size_t(m1) % 32 == 0);
344 _mm256_store_pd(m0, _mm512_castpd512_pd256(a.simdInternal_));
345 _mm256_store_pd(m1, _mm512_extractf64x4_pd(a.simdInternal_, 1));
// m0[i] += a[i] and m1[i] += a[i + 4] for i = 0..3, using aligned 256-bit
// read-modify-writes. Both pointers must be 32-byte aligned (asserted).
// NOTE(review): the declaration of x (presumably __m256d) falls in a gap of
// this listing.
348 static inline void gmx_simdcall incrDualHsimd(double* m0, double* m1, SimdDouble a)
350 assert(std::size_t(m0) % 32 == 0);
351 assert(std::size_t(m1) % 32 == 0);
// Lower half of a into m0.
356 x = _mm256_load_pd(m0);
357 x = _mm256_add_pd(x, _mm512_castpd512_pd256(a.simdInternal_));
358 _mm256_store_pd(m0, x);
// Upper half of a into m1.
361 x = _mm256_load_pd(m1);
362 x = _mm256_add_pd(x, _mm512_extractf64x4_pd(a.simdInternal_, 1));
363 _mm256_store_pd(m1, x);
// Applies decrHsimd to three consecutive groups of GMX_SIMD_DOUBLE_WIDTH / 2
// doubles starting at m, one group for each of a0, a1 and a2.
// NOTE(review): the numbering jumps 366 -> 369, so the call that handles a0
// (presumably decrHsimd(m, a0) on the first group) is not shown in this
// listing. As shown, a0 would be unused — confirm against the full source.
366 static inline void gmx_simdcall decr3Hsimd(double* m, SimdDouble a0, SimdDouble a1, SimdDouble a2)
369 decrHsimd(m + GMX_SIMD_DOUBLE_WIDTH / 2, a1);
370 decrHsimd(m + GMX_SIMD_DOUBLE_WIDTH, a2);
// Gathers two components from align-strided records of two arrays.
// With j_i = offset[i] * align for i = 0..3:
//   *v0 = (base0[j_0..j_3], base1[j_0..j_3])          (first components)
//   *v1 = (base0[j_0+1..j_3+1], base1[j_0+1..j_3+1])  (second components)
// offset, base0 and base1 must be 16-byte aligned (asserted). align must be
// 2 or 4 (static_assert).
// NOTE(review): the template header, the base1/v0/v1 parameter lines and the
// declarations of idx0, idx1, idx, tmp1 and tmp2 fall in numbering gaps.
374 static inline void gmx_simdcall gatherLoadTransposeHsimd(const double* base0,
376 const std::int32_t offset[],
384 assert(std::size_t(offset) % 16 == 0);
385 assert(std::size_t(base0) % 16 == 0);
386 assert(std::size_t(base1) % 16 == 0);
// Four 32-bit record indices.
388 idx0 = _mm_load_si128(reinterpret_cast<const __m128i*>(offset));
// Because align is 2 or 4, multiplying by it is a left shift by 1 or 2.
390 static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
391 idx0 = _mm_slli_epi32(idx0, align == 2 ? 1 : 2);
// Eight gather indices: (j_0..j_3) in the low lanes, (j_0+1..j_3+1) in the high lanes.
393 idx1 = _mm_add_epi32(idx0, _mm_set1_epi32(1));
395 idx = _mm256_inserti128_si256(_mm256_castsi128_si256(idx0), idx1, 1);
397 constexpr size_t scale = sizeof(double);
398 tmp1 = _mm512_i32gather_pd(idx, base0, scale); // TODO: Might be faster to use individual loads
399 tmp2 = _mm512_i32gather_pd(idx, base1, scale);
// 0x44 keeps 128-bit lanes 0, 1 of tmp1 and tmp2 (first components);
// 0xEE keeps lanes 2, 3 (second components).
401 v0->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0x44);
402 v1->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0xEE);
// Adds half-register sums to m and returns their total:
//   m[0] += v0[0] + v0[1] + v0[2] + v0[3]
//   m[1] += v0[4] + v0[5] + v0[6] + v0[7]
//   m[2] += v1[0] + v1[1] + v1[2] + v1[3]
//   m[3] += v1[4] + v1[5] + v1[6] + v1[7]
// Returns the sum of all 16 elements of v0 and v1. m must be 32-byte aligned
// (asserted).
// NOTE(review): the declarations of t0, t2 and t3 fall in a gap of this
// listing (lines 406-409).
405 static inline double gmx_simdcall reduceIncr4ReturnSumHsimd(double* m, SimdDouble v0, SimdDouble v1)
410 assert(std::size_t(m) % 32 == 0);
// Add each element to its partner two places away (permutex 0x4E swaps 128-bit
// halves within each 256-bit lane). Elements 2, 3, 6, 7 (mask 0xCC) take the
// same reduction of v1.
412 t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permutex_pd(v0.simdInternal_, 0x4E));
413 t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xCC), v1.simdInternal_,
414 _mm512_permutex_pd(v1.simdInternal_, 0x4E));
// Add adjacent elements (0xB1). Elements 0/1 = sum of v0[0..3], 2/3 = v1[0..3],
// 4/5 = v0[4..7], 6/7 = v1[4..7].
415 t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
// Odd elements (mask 0xAA) take the upper-half sums. The lower 256 bits become
// (v0 low, v0 high, v1 low, v1 high).
416 t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0xAA), t0, t0, 0xEE);
// m[0..3] += those four sums (aligned 256-bit read-modify-write).
418 t2 = _mm512_castpd512_pd256(t0);
419 t3 = _mm256_load_pd(m);
420 t3 = _mm256_add_pd(t3, t2);
421 _mm256_store_pd(m, t3);
// Fold the four sums into element 0 and return it.
423 t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
424 t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
426 return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
// Returns (m[0..3], m[offset..offset+3]) using unaligned 256-bit loads.
// offset is measured in doubles, not bytes.
429 static inline SimdDouble gmx_simdcall loadU4NOffset(const double* m, int offset)
431 return { _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_loadu_pd(m)),
432 _mm256_loadu_pd(m + offset), 1) };
437 #endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H