/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H

#include "config.h"

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_float.h"

namespace gmx
{

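// Best alignment to use for aligned pairs of float data; kept at 2 so the
// pair stride can be folded into the hardware gather/scatter scale used below.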
static const int c_simdBestPairAlignmentFloat = 2;

namespace
{
// Multiply function optimized for powers of 2, for which the multiplication
// is done by shifting. Currently up to 8 is accelerated. Could be accelerated
// for any power of 2 with a constexpr log2 function.
template<int n>
SimdFInt32 fastMultiply(SimdFInt32 x)
{
    if (n == 2)
    {
        return _mm512_slli_epi32(x.simdInternal_, 1);
    }
    else if (n == 4)
    {
        return _mm512_slli_epi32(x.simdInternal_, 2);
    }
    else if (n == 8)
    {
        return _mm512_slli_epi32(x.simdInternal_, 3);
    }
    else
    {
        return x * n;
    }
}

template<int align>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const float *, SimdFInt32)
{
    // Nothing to do. Termination of recursion.
}
}

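// Gather, for each SIMD lane i, the fields base[align*offset[i] + 0], +1, ...
// into one SIMD output variable per field (transposed layout). The variadic
// template peels off one output per recursion step; the empty overload above
// terminates the recursion. For align > 2 the offset is multiplied up front,
// so the recursive calls continue with an effective alignment of 1; for
// align <= 2 the stride is folded into the gather scale instead.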
template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const float *base, SimdFInt32 offset, SimdFloat *v, Targs... Fargs)
{
    // For align 1 or 2: No multiplication of offset is needed
    if (align > 2)
    {
        offset = fastMultiply<align>(offset);
    }
    // For align 2: Scale of 2*sizeof(float) is used (maximum supported scale)
    constexpr int align_ = (align > 2) ? 1 : align;
    v->simdInternal_ = _mm512_i32gather_ps(offset.simdInternal_, base, sizeof(float)*align_);
    // Gather remaining elements. Avoid extra multiplication (new align is 1 or 2).
    gatherLoadBySimdIntTranspose<align_>(base+1, offset, Fargs ...);
}

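// The unaligned (U) and integer-array variants below are thin wrappers:
// the AVX-512 gather used above has no alignment requirement, so they can
// all forward to the SIMD-integer gather.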
template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadUBySimdIntTranspose(const float *base, SimdFInt32 offset, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, offset, Fargs ...);
}

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadTranspose(const float *base, const std::int32_t offset[], Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), Fargs ...);
}

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadUTranspose(const float *base, const std::int32_t offset[], Targs... Fargs)
{
    gatherLoadTranspose<align>(base, offset, Fargs ...);
}

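// Transpose v0..v2 and scatter the resulting x/y/z triplets to memory, so
// that lane i of v0, v1, v2 ends up at base[align*offset[i]] + 0, 1, 2.
// Implemented with three hardware scatters; no alignment is required.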
template <int align>
static inline void gmx_simdcall
transposeScatterStoreU(float *              base,
                       const std::int32_t   offset[],
                       SimdFloat            v0,
                       SimdFloat            v1,
                       SimdFloat            v2)
{
    SimdFInt32 simdoffset = simdLoad(offset, SimdFInt32Tag());
    if (align > 2)
    {
        simdoffset = fastMultiply<align>(simdoffset);
    }
    constexpr size_t scale = (align > 2) ? sizeof(float) : sizeof(float) * align;

    _mm512_i32scatter_ps(base,       simdoffset.simdInternal_, v0.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
}

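// Transpose v0..v2 into x/y/z triplets and add each triplet to the three
// floats at base[align*offset[i]]. For align < 4 only three elements per
// lane are touched (masked stores); for align >= 4 full 4-wide loads/stores
// are used and the padding element is rewritten unchanged (zero is added to
// it), with aligned variants chosen when align % 4 == 0.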
template <int align>
static inline void gmx_simdcall
transposeScatterIncrU(float *              base,
                      const std::int32_t   offset[],
                      SimdFloat            v0,
                      SimdFloat            v1,
                      SimdFloat            v2)
{
    __m512 t[4], t5, t6, t7, t8;
    int    i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t    o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_add_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[ 4 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_add_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[ 8 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_add_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_add_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5    = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6    = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7    = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8    = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0]  = _mm512_unpacklo_ps(t5, t7);                             // x0 y0 z0  0 | x4 y4 z4 0
        t[1]  = _mm512_unpackhi_ps(t5, t7);                             // x1 y1 z1  0 | x5 y5 z5 0
        t[2]  = _mm512_unpacklo_ps(t6, t8);                             // x2 y2 z2  0 | x6 y6 z6 0
        t[3]  = _mm512_unpackhi_ps(t6, t8);                             // x3 y3 z3  0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i], _mm_add_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[ 4 + i],
                             _mm_add_ps(_mm_load_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[ 8 + i],
                             _mm_add_ps(_mm_load_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i],
                             _mm_add_ps(_mm_load_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i], _mm_add_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[ 4 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[ 8 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}

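// Same transpose-and-update pattern as transposeScatterIncrU above, but the
// triplets are subtracted from memory instead of added.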
template <int align>
static inline void gmx_simdcall
transposeScatterDecrU(float *              base,
                      const std::int32_t   offset[],
                      SimdFloat            v0,
                      SimdFloat            v1,
                      SimdFloat            v2)
{
    __m512 t[4], t5, t6, t7, t8;
    int    i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t    o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_sub_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[ 4 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_sub_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[ 8 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_sub_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7), _mm512_castps128_ps512(
                                          _mm_sub_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5    = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6    = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7    = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8    = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0]  = _mm512_unpacklo_ps(t5, t7);                             // x0 y0 z0  0 | x4 y4 z4 0
        t[1]  = _mm512_unpackhi_ps(t5, t7);                             // x1 y1 z1  0 | x5 y5 z5 0
        t[2]  = _mm512_unpacklo_ps(t6, t8);                             // x2 y2 z2  0 | x6 y6 z6 0
        t[3]  = _mm512_unpackhi_ps(t6, t8);                             // x3 y3 z3  0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i], _mm_sub_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[ 4 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[ 8 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i], _mm_sub_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[ 4 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[ 4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[ 8 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[ 8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}

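// Expand each input scalar into three consecutive copies, spread over the
// three output vectors (with s0..s15 denoting the input lanes):
//   triplets0 = s0 s0 s0 s1 s1 s1 s2 s2 s2 s3 s3 s3 s4 s4 s4 s5
//   triplets1 = s5 s5 s6 s6 s6 s7 s7 s7 s8 s8 s8 s9 s9 s9 s10 s10
//   triplets2 = s10 s11 s11 s11 s12 s12 s12 s13 s13 s13 s14 s14 s14 s15 s15 s15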
static inline void gmx_simdcall
expandScalarsToTriplets(SimdFloat    scalar,
                        SimdFloat *  triplets0,
                        SimdFloat *  triplets1,
                        SimdFloat *  triplets2)
{
    triplets0->simdInternal_ = _mm512_permutexvar_ps(_mm512_set_epi32(5, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1, 0, 0, 0),
                                                     scalar.simdInternal_);
    triplets1->simdInternal_ = _mm512_permutexvar_ps(_mm512_set_epi32(10, 10, 9, 9, 9, 8, 8, 8, 7, 7, 7, 6, 6, 6, 5, 5),
                                                     scalar.simdInternal_);
    triplets2->simdInternal_ = _mm512_permutexvar_ps(_mm512_set_epi32(15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12, 11, 11, 11, 10),
                                                     scalar.simdInternal_);
}

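// Reduce each of v0..v3 to a scalar sum, add the four sums to m[0]..m[3]
// (m must be 16-byte aligned, see the assert), and return the grand total
// of all four sums.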
static inline float gmx_simdcall
reduceIncr4ReturnSum(float *    m,
                     SimdFloat  v0,
                     SimdFloat  v1,
                     SimdFloat  v2,
                     SimdFloat  v3)
{
    __m512 t0, t1, t2;
    __m128 t3, t4;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_add_ps(v0.simdInternal_, _mm512_permute_ps(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_ps(t0, avx512Int2Mask(0xCCCC), v2.simdInternal_, _mm512_permute_ps(v2.simdInternal_, 0x4E));
    t1 = _mm512_add_ps(v1.simdInternal_, _mm512_permute_ps(v1.simdInternal_, 0x4E));
    t1 = _mm512_mask_add_ps(t1, avx512Int2Mask(0xCCCC), v3.simdInternal_, _mm512_permute_ps(v3.simdInternal_, 0x4E));
    t2 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t2 = _mm512_mask_add_ps(t2, avx512Int2Mask(0xAAAA), t1, _mm512_permute_ps(t1, 0xB1));

    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0x4E));
    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0xB1));

    t3 = _mm512_castps512_ps128(t2);
    t4 = _mm_load_ps(m);
    t4 = _mm_add_ps(t4, t3);
    _mm_store_ps(m, t4);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}

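// Load two half-width (8-float) blocks from m0 and m1 into the low and high
// 256-bit halves of the result; both pointers must be 32-byte aligned.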
static inline SimdFloat gmx_simdcall
loadDualHsimd(const float * m0,
              const float * m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return {
               _mm512_castpd_ps(_mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_load_pd(reinterpret_cast<const double*>(m0))),
                                                   _mm256_load_pd(reinterpret_cast<const double*>(m1)), 1))
    };
}

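// Load 8 floats from 32-byte-aligned m and duplicate them into both 256-bit
// halves of the result.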
static inline SimdFloat gmx_simdcall
loadDuplicateHsimd(const float * m)
{
    assert(std::size_t(m) % 32 == 0);
    return {
               _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_load_pd(reinterpret_cast<const double*>(m))))
    };
}

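// Broadcast m[0] into the low 8 lanes and m[1] into the high 8 lanes; the
// pointer does not need any particular alignment.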
static inline SimdFloat gmx_simdcall
loadU1DualHsimd(const float * m)
{
    return {
               _mm512_shuffle_f32x4(_mm512_broadcastss_ps(_mm_load_ss(m)),
                                    _mm512_broadcastss_ps(_mm_load_ss(m+1)), 0x44)
    };
}


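// Store the low 256-bit half of a to m0 and the high half to m1; both
// pointers must be 32-byte aligned.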
static inline void gmx_simdcall
storeDualHsimd(float *     m0,
               float *     m1,
               SimdFloat   a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_ps(m0, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_pd(reinterpret_cast<double*>(m1), _mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1));
}

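// Add the low 256-bit half of a to the 8 floats at m0 and the high half to
// the 8 floats at m1; both pointers must be 32-byte aligned.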
static inline void gmx_simdcall
incrDualHsimd(float *     m0,
              float *     m1,
              SimdFloat   a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256 x;

    // Lower half
    x = _mm256_load_ps(m0);
    x = _mm256_add_ps(x, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m0, x);

    // Upper half
    x = _mm256_load_ps(m1);
    x = _mm256_add_ps(x, _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1)));
    _mm256_store_ps(m1, x);
}

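// Sum the low and high 256-bit halves of a and subtract the 8 resulting
// values from the floats at 32-byte-aligned m.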
static inline void gmx_simdcall
decrHsimd(float *    m,
          SimdFloat  a)
{
    __m256 t;

    assert(std::size_t(m) % 32 == 0);

    a.simdInternal_ = _mm512_add_ps(a.simdInternal_, _mm512_shuffle_f32x4(a.simdInternal_, a.simdInternal_, 0xEE));
    t               = _mm256_load_ps(m);
    t               = _mm256_sub_ps(t, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m, t);
}


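// Gather pairs of consecutive floats from base0[align*offset[i]] and
// base1[align*offset[i]] for the 8 offsets, and transpose them so that v0
// holds the first and v1 the second element of every pair, with the base0
// data in the low 8 lanes and the base1 data in the high 8 lanes. Only
// align 2 and 4 are supported (see the static_assert).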
template <int align>
static inline void gmx_simdcall
gatherLoadTransposeHsimd(const float *        base0,
                         const float *        base1,
                         const std::int32_t   offset[],
                         SimdFloat *          v0,
                         SimdFloat *          v1)
{
    __m256i idx;
    __m512  tmp1, tmp2;

    assert(std::size_t(offset) % 32 == 0);
    assert(std::size_t(base0) % 8 == 0);
    assert(std::size_t(base1) % 8 == 0);

    idx = _mm256_load_si256(reinterpret_cast<const __m256i*>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    if (align == 4)
    {
        idx = _mm256_slli_epi32(idx, 1);
    }

    tmp1 = _mm512_castpd_ps(_mm512_i32gather_pd(idx, reinterpret_cast<const double *>(base0), sizeof(double)));
    tmp2 = _mm512_castpd_ps(_mm512_i32gather_pd(idx, reinterpret_cast<const double *>(base1), sizeof(double)));

    v0->simdInternal_ = _mm512_mask_moveldup_ps(tmp1, 0xAAAA, tmp2);
    v1->simdInternal_ = _mm512_mask_movehdup_ps(tmp2, 0x5555, tmp1);

    v0->simdInternal_ = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v0->simdInternal_);
    v1->simdInternal_ = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v1->simdInternal_);
}

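// Treat v0 and v1 as two half-width (8-float) values each: reduce the four
// halves to scalar sums, add them to m[0]..m[3] (low/high half of v0, then
// low/high half of v1), and return the total of the four sums.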
static inline float gmx_simdcall
reduceIncr4ReturnSumHsimd(float *     m,
                          SimdFloat   v0,
                          SimdFloat   v1)
{
    __m512 t0, t1;
    __m128 t2, t3;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0x88);
    t1 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0xDD);
    t0 = _mm512_add_ps(t0, t1);
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0x4E));
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t0 = _mm512_maskz_compress_ps(avx512Int2Mask(0x1111), t0);

    t3 = _mm512_castps512_ps128(t0);
    t2 = _mm_load_ps(m);
    t2 = _mm_add_ps(t2, t3);
    _mm_store_ps(m, t2);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}

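// Load 4 floats from an unaligned pointer and broadcast each of them to four
// consecutive lanes: f0 f0 f0 f0 f1 f1 f1 f1 f2 f2 f2 f2 f3 f3 f3 f3.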
static inline SimdFloat gmx_simdcall
loadUNDuplicate4(const float* f)
{
    return {
               _mm512_permute_ps(_mm512_maskz_expandloadu_ps(0x1111, f), 0)
    };
}

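// Load 4 floats from 16-byte-aligned f and repeat the quadruplet in all four
// 128-bit blocks: f0 f1 f2 f3 f0 f1 f2 f3 ...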
static inline SimdFloat gmx_simdcall
load4DuplicateN(const float* f)
{
    return {
               _mm512_broadcast_f32x4(_mm_load_ps(f))
    };
}

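// Load four unaligned groups of 4 consecutive floats located offset floats
// apart, i.e. f[0..3], f[offset..offset+3], f[2*offset..2*offset+3] and
// f[3*offset..3*offset+3], into the four 128-bit blocks of the result.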
static inline SimdFloat gmx_simdcall
loadU4NOffset(const float* f, int offset)
{
    const __m256i idx = _mm256_setr_epi32(0, 0, 1, 1, 2, 2, 3, 3);
    const __m256i gdx = _mm256_add_epi32(_mm256_setr_epi32(0, 2, 0, 2, 0, 2, 0, 2),
                                         _mm256_mullo_epi32(idx, _mm256_set1_epi32(offset)));
    return {
               _mm512_castpd_ps(_mm512_i32gather_pd(gdx, reinterpret_cast<const double*>(f), sizeof(float)))
    };
}

}      // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H