/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H

#include "config.h"

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_float.h"

namespace gmx
{

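// Preferred alignment (in float elements) for pair load/store operations; with an
// alignment of 2 the gathers below can use the hardware scale of 2*sizeof(float)
// directly, so no extra offset scaling is needed.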
static const int c_simdBestPairAlignmentFloat = 2;

namespace
{
// Multiply function optimized for powers of 2, where the multiplication is done by
// shifting. Currently powers of 2 up to 8 are accelerated; this could be extended to
// any power of 2 with a constexpr log2 function.
template<int n>
SimdFInt32 fastMultiply(SimdFInt32 x)
{
    if (n == 2)
    {
        return _mm512_slli_epi32(x.simdInternal_, 1);
    }
    else if (n == 4)
    {
        return _mm512_slli_epi32(x.simdInternal_, 2);
    }
    else if (n == 8)
    {
        return _mm512_slli_epi32(x.simdInternal_, 3);
    }
    else
    {
        return x * n;
    }
}

template<int align>
static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const float*, SimdFInt32)
{
    // Nothing to do. Termination of recursion.
}
} // namespace

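// Recursive transposed gather: each level gathers 16 floats from base + align*offset[i]
// into *v and then recurses with base + 1 for the remaining output vectors. AVX-512
// gathers have no alignment requirement, so the unaligned (U) and plain-array variants
// below simply forward to this function.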
template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v, Targs... Fargs)
{
    // For align 1 or 2: No multiplication of offset is needed
    if (align > 2)
    {
        offset = fastMultiply<align>(offset);
    }
    // For align 2: Scale of 2*sizeof(float) is used (maximum supported scale)
    constexpr int align_ = (align > 2) ? 1 : align;
    v->simdInternal_     = _mm512_i32gather_ps(offset.simdInternal_, base, sizeof(float) * align_);
    // Gather remaining elements. Avoid extra multiplication (new align is 1 or 2).
    gatherLoadBySimdIntTranspose<align_>(base + 1, offset, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadUBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, offset, v, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadTranspose(const float* base, const std::int32_t offset[], SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadUTranspose(const float* base, const std::int32_t offset[], SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v, Fargs...);
}

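// Transposed scatter store: writes lane i of v0, v1 and v2 to
// base[align*offset[i] + 0..2] using hardware scatter instructions. As in the gathers,
// offsets are pre-scaled when align > 2 so that a supported scale can be used.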
template<int align>
static inline void gmx_simdcall
                   transposeScatterStoreU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    SimdFInt32 simdoffset = simdLoad(offset, SimdFInt32Tag());
    if (align > 2)
    {
        simdoffset = fastMultiply<align>(simdoffset);
    }
    constexpr size_t scale = (align > 2) ? sizeof(float) : sizeof(float) * align;

    _mm512_i32scatter_ps(base, simdoffset.simdInternal_, v0.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
}

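// Transposed scatter-increment: adds the (v0, v1, v2) triplet of each lane to the three
// floats at base + align*offset[i]. The 16 lanes are first transposed into 4-float groups;
// for align < 4, masked unaligned stores update only three floats per location, while for
// align >= 4 full 4-float loads/stores are used (aligned ones when align is a multiple of 4).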
template<int align>
static inline void gmx_simdcall
                   transposeScatterIncrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    __m512                                   t[4], t5, t6, t7, t8;
    int                                      i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[i]),
                                                                    _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[4 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[4 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[8 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[8 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[12 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7   = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8   = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0] = _mm512_unpacklo_ps(t5, t7); // x0 y0 z0  0 | x4 y4 z4 0
        t[1] = _mm512_unpackhi_ps(t5, t7); // x1 y1 z1  0 | x5 y5 z5 0
        t[2] = _mm512_unpacklo_ps(t6, t8); // x2 y2 z2  0 | x6 y6 z6 0
        t[3] = _mm512_unpackhi_ps(t6, t8); // x3 y3 z3  0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i],
                             _mm_add_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[4 + i],
                             _mm_add_ps(_mm_load_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[8 + i],
                             _mm_add_ps(_mm_load_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i], _mm_add_ps(_mm_load_ps(base + o[12 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i],
                              _mm_add_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[4 + i], _mm_add_ps(_mm_loadu_ps(base + o[4 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[8 + i], _mm_add_ps(_mm_loadu_ps(base + o[8 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i], _mm_add_ps(_mm_loadu_ps(base + o[12 + i]),
                                                           _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}

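// Transposed scatter-decrement: identical to transposeScatterIncrU above, except that the
// triplets are subtracted from memory instead of added.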
template<int align>
static inline void gmx_simdcall
                   transposeScatterDecrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    __m512                                   t[4], t5, t6, t7, t8;
    int                                      i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[i]),
                                                                    _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[4 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[4 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[8 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[8 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[12 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7   = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8   = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0] = _mm512_unpacklo_ps(t5, t7); // x0 y0 z0  0 | x4 y4 z4 0
        t[1] = _mm512_unpackhi_ps(t5, t7); // x1 y1 z1  0 | x5 y5 z5 0
        t[2] = _mm512_unpacklo_ps(t6, t8); // x2 y2 z2  0 | x6 y6 z6 0
        t[3] = _mm512_unpackhi_ps(t6, t8); // x3 y3 z3  0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i],
                             _mm_sub_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[4 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[8 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i], _mm_sub_ps(_mm_load_ps(base + o[12 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[4 + i], _mm_sub_ps(_mm_loadu_ps(base + o[4 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[8 + i], _mm_sub_ps(_mm_loadu_ps(base + o[8 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i], _mm_sub_ps(_mm_loadu_ps(base + o[12 + i]),
                                                           _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}

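// Expands 16 scalars s0..s15 into the 48-element stream s0,s0,s0,s1,s1,s1,...,s15,s15,s15,
// returned as three consecutive SIMD registers (triplets0 ends with the first copy of s5,
// triplets2 starts with the last copy of s10).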
static inline void gmx_simdcall expandScalarsToTriplets(SimdFloat  scalar,
                                                        SimdFloat* triplets0,
                                                        SimdFloat* triplets1,
                                                        SimdFloat* triplets2)
{
    triplets0->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(5, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1, 0, 0, 0), scalar.simdInternal_);
    triplets1->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(10, 10, 9, 9, 9, 8, 8, 8, 7, 7, 7, 6, 6, 6, 5, 5), scalar.simdInternal_);
    triplets2->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12, 11, 11, 11, 10),
            scalar.simdInternal_);
}


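// Reduces each of v0..v3 to a single sum, adds the four sums to m[0..3] (which must be
// 16-byte aligned), and returns the total of all four reductions.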
static inline float gmx_simdcall reduceIncr4ReturnSum(float* m, SimdFloat v0, SimdFloat v1, SimdFloat v2, SimdFloat v3)
{
    __m512 t0, t1, t2;
    __m128 t3, t4;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_add_ps(v0.simdInternal_, _mm512_permute_ps(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_ps(t0, avx512Int2Mask(0xCCCC), v2.simdInternal_,
                            _mm512_permute_ps(v2.simdInternal_, 0x4E));
    t1 = _mm512_add_ps(v1.simdInternal_, _mm512_permute_ps(v1.simdInternal_, 0x4E));
    t1 = _mm512_mask_add_ps(t1, avx512Int2Mask(0xCCCC), v3.simdInternal_,
                            _mm512_permute_ps(v3.simdInternal_, 0x4E));
    t2 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t2 = _mm512_mask_add_ps(t2, avx512Int2Mask(0xAAAA), t1, _mm512_permute_ps(t1, 0xB1));

    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0x4E));
    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0xB1));

    t3 = _mm512_castps512_ps128(t2);
    t4 = _mm_load_ps(m);
    t4 = _mm_add_ps(t4, t3);
    _mm_store_ps(m, t4);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}

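// The "Hsimd" functions below operate on half-SIMD-width (8-float) data, mapped to the
// low and high 256-bit halves of a __m512 register. loadDualHsimd loads 8 floats from m0
// into the low half and 8 floats from m1 into the high half; both pointers must be
// 32-byte aligned.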
static inline SimdFloat gmx_simdcall loadDualHsimd(const float* m0, const float* m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return { _mm512_castpd_ps(_mm512_insertf64x4(
            _mm512_castpd256_pd512(_mm256_load_pd(reinterpret_cast<const double*>(m0))),
            _mm256_load_pd(reinterpret_cast<const double*>(m1)), 1)) };
}

static inline SimdFloat gmx_simdcall loadDuplicateHsimd(const float* m)
{
    assert(std::size_t(m) % 32 == 0);
    return { _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_load_pd(reinterpret_cast<const double*>(m)))) };
}

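// Broadcasts m[0] across the low 8 floats and m[1] across the high 8 floats;
// m does not need to be aligned.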
static inline SimdFloat gmx_simdcall loadU1DualHsimd(const float* m)
{
    return { _mm512_shuffle_f32x4(_mm512_broadcastss_ps(_mm_load_ss(m)),
                                  _mm512_broadcastss_ps(_mm_load_ss(m + 1)), 0x44) };
}


static inline void gmx_simdcall storeDualHsimd(float* m0, float* m1, SimdFloat a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_ps(m0, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_pd(reinterpret_cast<double*>(m1),
                    _mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1));
}

static inline void gmx_simdcall incrDualHsimd(float* m0, float* m1, SimdFloat a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256 x;

    // Lower half
    x = _mm256_load_ps(m0);
    x = _mm256_add_ps(x, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m0, x);

    // Upper half
    x = _mm256_load_ps(m1);
    x = _mm256_add_ps(x, _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1)));
    _mm256_store_ps(m1, x);
}

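// Adds the high 8-float half of a to its low half and subtracts the result from m[0..7];
// m must be 32-byte aligned.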
static inline void gmx_simdcall decrHsimd(float* m, SimdFloat a)
{
    __m256 t;

    assert(std::size_t(m) % 32 == 0);

    a.simdInternal_ = _mm512_add_ps(a.simdInternal_,
                                    _mm512_shuffle_f32x4(a.simdInternal_, a.simdInternal_, 0xEE));
    t               = _mm256_load_ps(m);
    t               = _mm256_sub_ps(t, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m, t);
}


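// Gathers a pair of floats from base0 + align*offset[i] and from base1 + align*offset[i]
// for each of the 8 offsets, then transposes so that v0 holds all first elements and v1
// all second elements, with base0 data in the low halves and base1 data in the high halves.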
template<int align>
static inline void gmx_simdcall gatherLoadTransposeHsimd(const float*       base0,
                                                         const float*       base1,
                                                         const std::int32_t offset[],
                                                         SimdFloat*         v0,
                                                         SimdFloat*         v1)
{
    __m256i idx;
    __m512  tmp1, tmp2;

    assert(std::size_t(offset) % 32 == 0);
    assert(std::size_t(base0) % 8 == 0);
    assert(std::size_t(base1) % 8 == 0);

    idx = _mm256_load_si256(reinterpret_cast<const __m256i*>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    if (align == 4)
    {
        idx = _mm256_slli_epi32(idx, 1);
    }

    tmp1 = _mm512_castpd_ps(
            _mm512_i32gather_pd(idx, reinterpret_cast<const double*>(base0), sizeof(double)));
    tmp2 = _mm512_castpd_ps(
            _mm512_i32gather_pd(idx, reinterpret_cast<const double*>(base1), sizeof(double)));

    v0->simdInternal_ = _mm512_mask_moveldup_ps(tmp1, 0xAAAA, tmp2);
    v1->simdInternal_ = _mm512_mask_movehdup_ps(tmp2, 0x5555, tmp1);

    v0->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v0->simdInternal_);
    v1->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v1->simdInternal_);
}

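// Reduces the low and high halves of v0 and v1 to four sums (v0-low, v0-high, v1-low,
// v1-high), adds them to m[0..3] (16-byte aligned), and returns their total.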
static inline float gmx_simdcall reduceIncr4ReturnSumHsimd(float* m, SimdFloat v0, SimdFloat v1)
{
    __m512 t0, t1;
    __m128 t2, t3;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0x88);
    t1 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0xDD);
    t0 = _mm512_add_ps(t0, t1);
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0x4E));
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t0 = _mm512_maskz_compress_ps(avx512Int2Mask(0x1111), t0);

    t3 = _mm512_castps512_ps128(t0);
    t2 = _mm_load_ps(m);
    t2 = _mm_add_ps(t2, t3);
    _mm_store_ps(m, t2);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}

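// loadUNDuplicate4: loads f[0..3] (unaligned) and duplicates each value across one
// 128-bit lane. load4DuplicateN: broadcasts the aligned block f[0..3] to all four lanes.
// loadU4NOffset: gathers the four unaligned 4-float blocks starting at f + i*offset
// (i = 0..3), one per 128-bit lane.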
static inline SimdFloat gmx_simdcall loadUNDuplicate4(const float* f)
{
    return { _mm512_permute_ps(_mm512_maskz_expandloadu_ps(0x1111, f), 0) };
}

static inline SimdFloat gmx_simdcall load4DuplicateN(const float* f)
{
    return { _mm512_broadcast_f32x4(_mm_load_ps(f)) };
}

static inline SimdFloat gmx_simdcall loadU4NOffset(const float* f, int offset)
{
    const __m256i idx = _mm256_setr_epi32(0, 0, 1, 1, 2, 2, 3, 3);
    const __m256i gdx = _mm256_add_epi32(_mm256_setr_epi32(0, 2, 0, 2, 0, 2, 0, 2),
                                         _mm256_mullo_epi32(idx, _mm256_set1_epi32(offset)));
    return { _mm512_castpd_ps(_mm512_i32gather_pd(gdx, reinterpret_cast<const double*>(f), sizeof(float))) };
}

} // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H