/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017,2018 by the GROMACS development team.
 * Copyright (c) 2019,2020, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H

#include "config.h"

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_float.h"

namespace gmx
{

static const int c_simdBestPairAlignmentFloat = 2;

namespace
{
// Multiply function optimized for powers of 2, for which the multiplication is
// done by shifting. Currently alignments up to 8 are accelerated; any power of
// two could be handled with a constexpr log2 helper.
template<int n>
static inline SimdFInt32 fastMultiply(SimdFInt32 x)
{
    if (n == 2)
    {
        return _mm512_slli_epi32(x.simdInternal_, 1);
    }
    else if (n == 4)
    {
        return _mm512_slli_epi32(x.simdInternal_, 2);
    }
    else if (n == 8)
    {
        return _mm512_slli_epi32(x.simdInternal_, 3);
    }
    else
    {
        return x * n;
    }
}

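// A possible generalization of fastMultiply above, as hinted in its comment
// (illustrative sketch only, not used in this file; names are hypothetical):
//
//     constexpr int constexprLog2(int x)
//     {
//         return x <= 1 ? 0 : 1 + constexprLog2(x / 2);
//     }
//
//     template<int n>
//     static inline SimdFInt32 fastMultiplyAnyPowerOfTwo(SimdFInt32 x)
//     {
//         if (n > 0 && (n & (n - 1)) == 0) // n is a power of two: multiply by shifting
//         {
//             return _mm512_slli_epi32(x.simdInternal_, constexprLog2(n));
//         }
//         else
//         {
//             return x * n;
//         }
//     }
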
template<int align>
static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const float*, SimdFInt32)
{
    // Nothing to do. Termination of recursion.
}

/* This is an internal helper function used by decr3Hsimd(...).
 */
inline void gmx_simdcall decrHsimd(float* m, SimdFloat a)
{
    __m256 t;

    assert(std::size_t(m) % 32 == 0);

    a.simdInternal_ = _mm512_add_ps(a.simdInternal_,
                                    _mm512_shuffle_f32x4(a.simdInternal_, a.simdInternal_, 0xEE));
    t               = _mm256_load_ps(m);
    t               = _mm256_sub_ps(t, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m, t);
}
} // namespace

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v, Targs... Fargs)
{
    // For align 1 or 2: No multiplication of offset is needed
    if (align > 2)
    {
        offset = fastMultiply<align>(offset);
    }
    // For align 2: Scale of 2*sizeof(float) is used (maximum supported scale)
    constexpr int align_ = (align > 2) ? 1 : align;
    v->simdInternal_     = _mm512_i32gather_ps(offset.simdInternal_, base, sizeof(float) * align_);
    // Gather remaining elements. Avoid extra multiplication (new align is 1 or 2).
    gatherLoadBySimdIntTranspose<align_>(base + 1, offset, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadUBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, offset, v, Fargs...);
}

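// Illustrative use of the transposed gathers below (hypothetical data layout):
// with coordinates stored as 4-float-padded x,y,z triplets, so that atom i
// occupies base[4*offset[i] + 0..2],
//
//     SimdFloat x, y, z;
//     gatherLoadTranspose<4>(base, offset, &x, &y, &z);
//
// loads element j of x from base[4*offset[j] + 0], y from + 1 and z from + 2.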
template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadTranspose(const float* base, const std::int32_t offset[], SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadUTranspose(const float* base, const std::int32_t offset[], SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v, Fargs...);
}

template<int align>
static inline void gmx_simdcall
                   transposeScatterStoreU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    SimdFInt32 simdoffset = simdLoad(offset, SimdFInt32Tag());
    if (align > 2)
    {
        simdoffset = fastMultiply<align>(simdoffset);
    }
    constexpr size_t scale = (align > 2) ? sizeof(float) : sizeof(float) * align;

    _mm512_i32scatter_ps(base, simdoffset.simdInternal_, v0.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
}

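// transposeScatterIncrU transposes v0/v1/v2 into (x,y,z) triplets and adds each
// triplet to memory at base + align*offset[i]. For align < 4, a masked unaligned
// store updates exactly three floats per triplet; for align >= 4, full 4-float
// SSE loads/stores are used (aligned variants when align is a multiple of 4),
// which re-writes the padding element unchanged since zero is added to it.
// transposeScatterDecrU further below performs the same operation with subtraction.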
template<int align>
static inline void gmx_simdcall
                   transposeScatterIncrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    __m512                                   t[4], t5, t6, t7, t8;
    int                                      i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[i]),
                                                                    _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[4 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[4 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[8 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[8 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[12 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7   = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8   = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0] = _mm512_unpacklo_ps(t5, t7); // x0 y0 z0  0 | x4 y4 z4 0
        t[1] = _mm512_unpackhi_ps(t5, t7); // x1 y1 z1  0 | x5 y5 z5 0
        t[2] = _mm512_unpacklo_ps(t6, t8); // x2 y2 z2  0 | x6 y6 z6 0
        t[3] = _mm512_unpackhi_ps(t6, t8); // x3 y3 z3  0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i],
                             _mm_add_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[4 + i],
                             _mm_add_ps(_mm_load_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[8 + i],
                             _mm_add_ps(_mm_load_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i], _mm_add_ps(_mm_load_ps(base + o[12 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i],
                              _mm_add_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[4 + i], _mm_add_ps(_mm_loadu_ps(base + o[4 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[8 + i], _mm_add_ps(_mm_loadu_ps(base + o[8 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i], _mm_add_ps(_mm_loadu_ps(base + o[12 + i]),
                                                           _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}

template<int align>
static inline void gmx_simdcall
                   transposeScatterDecrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    __m512                                   t[4], t5, t6, t7, t8;
    int                                      i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[i]),
                                                                    _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[4 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[4 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[8 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[8 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i], avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[12 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7   = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8   = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0] = _mm512_unpacklo_ps(t5, t7); // x0 y0 z0  0 | x4 y4 z4 0
        t[1] = _mm512_unpackhi_ps(t5, t7); // x1 y1 z1  0 | x5 y5 z5 0
        t[2] = _mm512_unpacklo_ps(t6, t8); // x2 y2 z2  0 | x6 y6 z6 0
        t[3] = _mm512_unpackhi_ps(t6, t8); // x3 y3 z3  0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i],
                             _mm_sub_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[4 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[8 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i], _mm_sub_ps(_mm_load_ps(base + o[12 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[4 + i], _mm_sub_ps(_mm_loadu_ps(base + o[4 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[8 + i], _mm_sub_ps(_mm_loadu_ps(base + o[8 + i]),
                                                          _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i], _mm_sub_ps(_mm_loadu_ps(base + o[12 + i]),
                                                           _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}

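// The three outputs concatenated hold each input element repeated three times in
// order, i.e. s0,s0,s0,s1,s1,s1,...,s15,s15,s15, split across triplets0..2.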
static inline void gmx_simdcall expandScalarsToTriplets(SimdFloat  scalar,
                                                        SimdFloat* triplets0,
                                                        SimdFloat* triplets1,
                                                        SimdFloat* triplets2)
{
    triplets0->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(5, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1, 0, 0, 0), scalar.simdInternal_);
    triplets1->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(10, 10, 9, 9, 9, 8, 8, 8, 7, 7, 7, 6, 6, 6, 5, 5), scalar.simdInternal_);
    triplets2->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12, 11, 11, 11, 10),
            scalar.simdInternal_);
}


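// Reduces each of v0..v3 to a single sum, adds those four sums to m[0]..m[3],
// and returns the total of the four sums.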
static inline float gmx_simdcall reduceIncr4ReturnSum(float* m, SimdFloat v0, SimdFloat v1, SimdFloat v2, SimdFloat v3)
{
    __m512 t0, t1, t2;
    __m128 t3, t4;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_add_ps(v0.simdInternal_, _mm512_permute_ps(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_ps(t0, avx512Int2Mask(0xCCCC), v2.simdInternal_,
                            _mm512_permute_ps(v2.simdInternal_, 0x4E));
    t1 = _mm512_add_ps(v1.simdInternal_, _mm512_permute_ps(v1.simdInternal_, 0x4E));
    t1 = _mm512_mask_add_ps(t1, avx512Int2Mask(0xCCCC), v3.simdInternal_,
                            _mm512_permute_ps(v3.simdInternal_, 0x4E));
    t2 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t2 = _mm512_mask_add_ps(t2, avx512Int2Mask(0xAAAA), t1, _mm512_permute_ps(t1, 0xB1));

    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0x4E));
    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0xB1));

    t3 = _mm512_castps512_ps128(t2);
    t4 = _mm_load_ps(m);
    t4 = _mm_add_ps(t4, t3);
    _mm_store_ps(m, t4);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}

static inline SimdFloat gmx_simdcall loadDualHsimd(const float* m0, const float* m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return { _mm512_castpd_ps(_mm512_insertf64x4(
            _mm512_castpd256_pd512(_mm256_load_pd(reinterpret_cast<const double*>(m0))),
            _mm256_load_pd(reinterpret_cast<const double*>(m1)), 1)) };
}

static inline SimdFloat gmx_simdcall loadDuplicateHsimd(const float* m)
{
    assert(std::size_t(m) % 32 == 0);
    return { _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_load_pd(reinterpret_cast<const double*>(m)))) };
}

static inline SimdFloat gmx_simdcall loadU1DualHsimd(const float* m)
{
    return { _mm512_shuffle_f32x4(_mm512_broadcastss_ps(_mm_load_ss(m)),
                                  _mm512_broadcastss_ps(_mm_load_ss(m + 1)), 0x44) };
}


static inline void gmx_simdcall storeDualHsimd(float* m0, float* m1, SimdFloat a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_ps(m0, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_pd(reinterpret_cast<double*>(m1),
                    _mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1));
}

static inline void gmx_simdcall incrDualHsimd(float* m0, float* m1, SimdFloat a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256 x;

    // Lower half
    x = _mm256_load_ps(m0);
    x = _mm256_add_ps(x, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m0, x);

    // Upper half
    x = _mm256_load_ps(m1);
    x = _mm256_add_ps(x, _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1)));
    _mm256_store_ps(m1, x);
}

static inline void gmx_simdcall decr3Hsimd(float* m, SimdFloat a0, SimdFloat a1, SimdFloat a2)
{
    decrHsimd(m, a0);
    decrHsimd(m + GMX_SIMD_FLOAT_WIDTH / 2, a1);
    decrHsimd(m + GMX_SIMD_FLOAT_WIDTH, a2);
}


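// For each offset, gathers a pair of adjacent floats from base0 (low half of the
// result) and base1 (high half), then transposes so that v0 holds the first float
// of every pair and v1 the second. The pairs are read as doubles, which is why
// align == 4 only needs the index shifted left by one.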
template<int align>
static inline void gmx_simdcall gatherLoadTransposeHsimd(const float*       base0,
                                                         const float*       base1,
                                                         const std::int32_t offset[],
                                                         SimdFloat*         v0,
                                                         SimdFloat*         v1)
{
    __m256i idx;
    __m512  tmp1, tmp2;

    assert(std::size_t(offset) % 32 == 0);
    assert(std::size_t(base0) % 8 == 0);
    assert(std::size_t(base1) % 8 == 0);

    idx = _mm256_load_si256(reinterpret_cast<const __m256i*>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    if (align == 4)
    {
        idx = _mm256_slli_epi32(idx, 1);
    }

    tmp1 = _mm512_castpd_ps(
            _mm512_i32gather_pd(idx, reinterpret_cast<const double*>(base0), sizeof(double)));
    tmp2 = _mm512_castpd_ps(
            _mm512_i32gather_pd(idx, reinterpret_cast<const double*>(base1), sizeof(double)));

    v0->simdInternal_ = _mm512_mask_moveldup_ps(tmp1, 0xAAAA, tmp2);
    v1->simdInternal_ = _mm512_mask_movehdup_ps(tmp2, 0x5555, tmp1);

    v0->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v0->simdInternal_);
    v1->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v1->simdInternal_);
}

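// Half-SIMD variant of reduceIncr4ReturnSum: v0 and v1 each carry two half-width
// (8-element) variables in their low and high 256-bit halves. The four half-sums
// are added to m[0]..m[3] and their total is returned.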
static inline float gmx_simdcall reduceIncr4ReturnSumHsimd(float* m, SimdFloat v0, SimdFloat v1)
{
    __m512 t0, t1;
    __m128 t2, t3;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0x88);
    t1 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0xDD);
    t0 = _mm512_add_ps(t0, t1);
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0x4E));
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t0 = _mm512_maskz_compress_ps(avx512Int2Mask(0x1111), t0);

    t3 = _mm512_castps512_ps128(t0);
    t2 = _mm_load_ps(m);
    t2 = _mm_add_ps(t2, t3);
    _mm_store_ps(m, t2);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}

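// Loads f[0..3] and broadcasts each value across one 128-bit lane:
// { f0,f0,f0,f0, f1,f1,f1,f1, f2,f2,f2,f2, f3,f3,f3,f3 }.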
static inline SimdFloat gmx_simdcall loadUNDuplicate4(const float* f)
{
    return { _mm512_permute_ps(_mm512_maskz_expandloadu_ps(0x1111, f), 0) };
}

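// Loads the four floats f[0..3] and repeats them in all four 128-bit lanes:
// { f0,f1,f2,f3, f0,f1,f2,f3, f0,f1,f2,f3, f0,f1,f2,f3 }.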
static inline SimdFloat gmx_simdcall load4DuplicateN(const float* f)
{
    return { _mm512_broadcast_f32x4(_mm_load_ps(f)) };
}

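// Loads four groups of four consecutive floats, with group i starting at
// f + i*offset, i.e. { f[0..3], f[offset..offset+3], f[2*offset..2*offset+3],
// f[3*offset..3*offset+3] }. The floats are gathered pairwise as doubles.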
static inline SimdFloat gmx_simdcall loadU4NOffset(const float* f, int offset)
{
    const __m256i idx = _mm256_setr_epi32(0, 0, 1, 1, 2, 2, 3, 3);
    const __m256i gdx = _mm256_add_epi32(_mm256_setr_epi32(0, 2, 0, 2, 0, 2, 0, 2),
                                         _mm256_mullo_epi32(idx, _mm256_set1_epi32(offset)));
    return { _mm512_castpd_ps(_mm512_i32gather_pd(gdx, reinterpret_cast<const double*>(f), sizeof(float))) };
}

} // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H