/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017,2018 by the GROMACS development team.
 * Copyright (c) 2019,2020, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H

#include "config.h"

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_float.h"

namespace gmx
{

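// Best alignment for pairwise (align = 2) gathers: a stride of two floats maps onto the
// largest scale (2*sizeof(float)) supported by the gather instruction, so the offsets never
// need to be pre-multiplied.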
static const int c_simdBestPairAlignmentFloat = 2;

namespace
{
// Multiply function optimized for powers of 2, for which it is done by
// shifting. Currently up to 8 is accelerated. Could be accelerated for any
// number with a constexpr log2 function.
template<int n>
static inline SimdFInt32 fastMultiply(SimdFInt32 x)
{
    if (n == 2)
    {
        return _mm512_slli_epi32(x.simdInternal_, 1);
    }
    else if (n == 4)
    {
        return _mm512_slli_epi32(x.simdInternal_, 2);
    }
    else if (n == 8)
    {
        return _mm512_slli_epi32(x.simdInternal_, 3);
    }
    else
    {
        return x * n;
    }
}

template<int align>
static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const float*, SimdFInt32)
{
    // Nothing to do. Termination of recursion.
}

/* This is an internal helper function used by decr3Hsimd(...).
 * It adds the upper and lower 256-bit halves of a and subtracts the
 * result from the eight floats starting at m.
 */
inline void gmx_simdcall decrHsimd(float* m, SimdFloat a)
{
    __m256 t;

    assert(std::size_t(m) % 32 == 0);

    a.simdInternal_ = _mm512_add_ps(a.simdInternal_,
                                    _mm512_shuffle_f32x4(a.simdInternal_, a.simdInternal_, 0xEE));
    t               = _mm256_load_ps(m);
    t               = _mm256_sub_ps(t, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m, t);
}
} // namespace

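// Gather one float per SIMD lane from base[align*offset[i]] into v, then recurse with
// base+1 to fill the remaining output vectors. For align > 2 the offsets are pre-multiplied
// once; for align <= 2 the multiplication is folded into the gather scale.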
template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v, Targs... Fargs)
{
    // For align 1 or 2: No multiplication of offset is needed
    if (align > 2)
    {
        offset = fastMultiply<align>(offset);
    }
    // For align 2: Scale of 2*sizeof(float) is used (maximum supported scale)
    constexpr int align_ = (align > 2) ? 1 : align;
    v->simdInternal_     = _mm512_i32gather_ps(offset.simdInternal_, base, sizeof(float) * align_);
    // Gather remaining elements. Avoid extra multiplication (new align is 1 or 2).
    gatherLoadBySimdIntTranspose<align_>(base + 1, offset, Fargs...);
}

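// The gather-based load above has no alignment requirement on base, so the unaligned
// variant simply forwards to it.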
template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadUBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, offset, v, Fargs...);
}

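// Load the integer offsets from memory and forward to the SIMD-integer gather above.
// gatherLoadUTranspose below is identical, since the gather imposes no alignment
// requirement on base.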
template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadTranspose(const float* base, const std::int32_t offset[], SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadUTranspose(const float* base, const std::int32_t offset[], SimdFloat* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v, Fargs...);
}

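// Scatter-store the three components to base[align*offset[i] + {0,1,2}]. For align > 2 the
// offsets are pre-multiplied and a scale of sizeof(float) is used; otherwise the alignment
// is folded into the scatter scale.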
template<int align>
static inline void gmx_simdcall
                   transposeScatterStoreU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    SimdFInt32 simdoffset = simdLoad(offset, SimdFInt32Tag());
    if (align > 2)
    {
        simdoffset = fastMultiply<align>(simdoffset);
    }
    constexpr size_t scale = (align > 2) ? sizeof(float) : sizeof(float) * align;

    _mm512_i32scatter_ps(base, simdoffset.simdInternal_, v0.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
    _mm512_i32scatter_ps(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
}

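// Add {v0,v1,v2}[i] to the three floats at base[align*offset[i]]. For align < 4, masked
// three-element stores avoid touching the fourth float of each group; for align >= 4 full
// 4-float loads/stores are used and the fourth element (zero in the transposed data) is
// written back unchanged, with aligned accesses when align is a multiple of 4.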
template<int align>
static inline void gmx_simdcall
                   transposeScatterIncrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    __m512                                   t[4], t5, t6, t7, t8;
    int                                      i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i],
                                  avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[i]),
                                                                    _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[4 + i],
                                  avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[4 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[8 + i],
                                  avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[8 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i],
                                  avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_add_ps(_mm_loadu_ps(base + o[12 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7   = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8   = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0] = _mm512_unpacklo_ps(t5, t7); // x0 y0 z0  0 | x4 y4 z4 0
        t[1] = _mm512_unpackhi_ps(t5, t7); // x1 y1 z1  0 | x5 y5 z5 0
        t[2] = _mm512_unpacklo_ps(t6, t8); // x2 y2 z2  0 | x6 y6 z6 0
        t[3] = _mm512_unpackhi_ps(t6, t8); // x3 y3 z3  0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i],
                             _mm_add_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[4 + i],
                             _mm_add_ps(_mm_load_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[8 + i],
                             _mm_add_ps(_mm_load_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i],
                             _mm_add_ps(_mm_load_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i],
                              _mm_add_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[4 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[8 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i],
                              _mm_add_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}

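// Identical to transposeScatterIncrU, except that {v0,v1,v2} are subtracted from memory.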
template<int align>
static inline void gmx_simdcall
                   transposeScatterDecrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
{
    __m512                                   t[4], t5, t6, t7, t8;
    int                                      i;
    alignas(GMX_SIMD_ALIGNMENT) std::int32_t o[16];
    store(o, fastMultiply<align>(simdLoad(offset, SimdFInt32Tag())));
    if (align < 4)
    {
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v1.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v1.simdInternal_);
        t[0] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(0, 0, 1, 0));
        t[1] = _mm512_shuffle_ps(t5, v2.simdInternal_, _MM_SHUFFLE(1, 1, 3, 2));
        t[2] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(2, 2, 1, 0));
        t[3] = _mm512_shuffle_ps(t6, v2.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
        for (i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_ps(base + o[i],
                                  avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[i]),
                                                                    _mm512_castps512_ps128(t[i]))));
            _mm512_mask_storeu_ps(base + o[4 + i],
                                  avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[4 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 1))));
            _mm512_mask_storeu_ps(base + o[8 + i],
                                  avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[8 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 2))));
            _mm512_mask_storeu_ps(base + o[12 + i],
                                  avx512Int2Mask(7),
                                  _mm512_castps128_ps512(_mm_sub_ps(_mm_loadu_ps(base + o[12 + i]),
                                                                    _mm512_extractf32x4_ps(t[i], 3))));
        }
    }
    else
    {
        // One could use shuffle here too if it is OK to overwrite the padded elements for alignment
        t5   = _mm512_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
        t6   = _mm512_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
        t7   = _mm512_unpacklo_ps(v1.simdInternal_, _mm512_setzero_ps());
        t8   = _mm512_unpackhi_ps(v1.simdInternal_, _mm512_setzero_ps());
        t[0] = _mm512_unpacklo_ps(t5, t7); // x0 y0 z0  0 | x4 y4 z4 0
        t[1] = _mm512_unpackhi_ps(t5, t7); // x1 y1 z1  0 | x5 y5 z5 0
        t[2] = _mm512_unpacklo_ps(t6, t8); // x2 y2 z2  0 | x6 y6 z6 0
        t[3] = _mm512_unpackhi_ps(t6, t8); // x3 y3 z3  0 | x7 y7 z7 0
        if (align % 4 == 0)
        {
            for (i = 0; i < 4; i++)
            {
                _mm_store_ps(base + o[i],
                             _mm_sub_ps(_mm_load_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_store_ps(base + o[4 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_store_ps(base + o[8 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_store_ps(base + o[12 + i],
                             _mm_sub_ps(_mm_load_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
        else
        {
            for (i = 0; i < 4; i++)
            {
                _mm_storeu_ps(base + o[i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[i]), _mm512_castps512_ps128(t[i])));
                _mm_storeu_ps(base + o[4 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[4 + i]), _mm512_extractf32x4_ps(t[i], 1)));
                _mm_storeu_ps(base + o[8 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[8 + i]), _mm512_extractf32x4_ps(t[i], 2)));
                _mm_storeu_ps(base + o[12 + i],
                              _mm_sub_ps(_mm_loadu_ps(base + o[12 + i]), _mm512_extractf32x4_ps(t[i], 3)));
            }
        }
    }
}

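// Replicate each of the 16 input scalars three times across the concatenation of the three
// outputs: triplets0 = {s0,s0,s0,s1,...,s5}, triplets1 = {s5,s5,s6,...,s10,s10},
// triplets2 = {s10,s11,...,s15,s15,s15}.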
static inline void gmx_simdcall expandScalarsToTriplets(SimdFloat  scalar,
                                                        SimdFloat* triplets0,
                                                        SimdFloat* triplets1,
                                                        SimdFloat* triplets2)
{
    triplets0->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(5, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1, 0, 0, 0), scalar.simdInternal_);
    triplets1->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(10, 10, 9, 9, 9, 8, 8, 8, 7, 7, 7, 6, 6, 6, 5, 5), scalar.simdInternal_);
    triplets2->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12, 11, 11, 11, 10),
            scalar.simdInternal_);
}

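// Sum all elements of each of v0..v3, add the four sums to m[0..3] (16-byte aligned),
// and return the total of the four sums.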
static inline float gmx_simdcall reduceIncr4ReturnSum(float* m, SimdFloat v0, SimdFloat v1, SimdFloat v2, SimdFloat v3)
{
    __m512 t0, t1, t2;
    __m128 t3, t4;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_add_ps(v0.simdInternal_, _mm512_permute_ps(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_ps(
            t0, avx512Int2Mask(0xCCCC), v2.simdInternal_, _mm512_permute_ps(v2.simdInternal_, 0x4E));
    t1 = _mm512_add_ps(v1.simdInternal_, _mm512_permute_ps(v1.simdInternal_, 0x4E));
    t1 = _mm512_mask_add_ps(
            t1, avx512Int2Mask(0xCCCC), v3.simdInternal_, _mm512_permute_ps(v3.simdInternal_, 0x4E));
    t2 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t2 = _mm512_mask_add_ps(t2, avx512Int2Mask(0xAAAA), t1, _mm512_permute_ps(t1, 0xB1));

    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0x4E));
    t2 = _mm512_add_ps(t2, _mm512_shuffle_f32x4(t2, t2, 0xB1));

    t3 = _mm512_castps512_ps128(t2);
    t4 = _mm_load_ps(m);
    t4 = _mm_add_ps(t4, t3);
    _mm_store_ps(m, t4);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}

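// Load 8 floats from m0 into the lower 256 bits and 8 floats from m1 into the upper
// 256 bits; both pointers must be 32-byte aligned.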
static inline SimdFloat gmx_simdcall loadDualHsimd(const float* m0, const float* m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return { _mm512_castpd_ps(_mm512_insertf64x4(
            _mm512_castpd256_pd512(_mm256_load_pd(reinterpret_cast<const double*>(m0))),
            _mm256_load_pd(reinterpret_cast<const double*>(m1)),
            1)) };
}

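// Load 8 floats from m (32-byte aligned) and duplicate them into both 256-bit halves.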
static inline SimdFloat gmx_simdcall loadDuplicateHsimd(const float* m)
{
    assert(std::size_t(m) % 32 == 0);
    return { _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_load_pd(reinterpret_cast<const double*>(m)))) };
}

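// Broadcast m[0] across the lower 256 bits and m[1] across the upper 256 bits; m does not
// need to be aligned.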
static inline SimdFloat gmx_simdcall loadU1DualHsimd(const float* m)
{
    return { _mm512_shuffle_f32x4(
            _mm512_broadcastss_ps(_mm_load_ss(m)), _mm512_broadcastss_ps(_mm_load_ss(m + 1)), 0x44) };
}

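// Store the lower 256 bits of a to m0 and the upper 256 bits to m1; both pointers must be
// 32-byte aligned.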
static inline void gmx_simdcall storeDualHsimd(float* m0, float* m1, SimdFloat a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_ps(m0, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_pd(reinterpret_cast<double*>(m1),
                    _mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1));
}

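// Add the lower 256 bits of a to the 8 floats at m0 and the upper 256 bits to the 8 floats
// at m1; both pointers must be 32-byte aligned.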
static inline void gmx_simdcall incrDualHsimd(float* m0, float* m1, SimdFloat a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256 x;

    // Lower half
    x = _mm256_load_ps(m0);
    x = _mm256_add_ps(x, _mm512_castps512_ps256(a.simdInternal_));
    _mm256_store_ps(m0, x);

    // Upper half
    x = _mm256_load_ps(m1);
    x = _mm256_add_ps(x, _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a.simdInternal_), 1)));
    _mm256_store_ps(m1, x);
}

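// For each of a0, a1 and a2, add its upper and lower 256-bit halves and subtract the result
// from the corresponding block of 8 floats starting at m (24 floats in total, 32-byte aligned).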
static inline void gmx_simdcall decr3Hsimd(float* m, SimdFloat a0, SimdFloat a1, SimdFloat a2)
{
    decrHsimd(m, a0);
    decrHsimd(m + GMX_SIMD_FLOAT_WIDTH / 2, a1);
    decrHsimd(m + GMX_SIMD_FLOAT_WIDTH, a2);
}

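// Gather pairs of floats from base0 and base1 at the eight half-width offsets. v0 receives
// the first float of each pair and v1 the second, with the base0 data in the lower 256 bits
// and the base1 data in the upper 256 bits.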
template<int align>
static inline void gmx_simdcall gatherLoadTransposeHsimd(const float*       base0,
                                                         const float*       base1,
                                                         const std::int32_t offset[],
                                                         SimdFloat*         v0,
                                                         SimdFloat*         v1)
{
    __m256i idx;
    __m512  tmp1, tmp2;

    assert(std::size_t(offset) % 32 == 0);
    assert(std::size_t(base0) % 8 == 0);
    assert(std::size_t(base1) % 8 == 0);

    idx = _mm256_load_si256(reinterpret_cast<const __m256i*>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    if (align == 4)
    {
        idx = _mm256_slli_epi32(idx, 1);
    }

    tmp1 = _mm512_castpd_ps(
            _mm512_i32gather_pd(idx, reinterpret_cast<const double*>(base0), sizeof(double)));
    tmp2 = _mm512_castpd_ps(
            _mm512_i32gather_pd(idx, reinterpret_cast<const double*>(base1), sizeof(double)));

    v0->simdInternal_ = _mm512_mask_moveldup_ps(tmp1, 0xAAAA, tmp2);
    v1->simdInternal_ = _mm512_mask_movehdup_ps(tmp2, 0x5555, tmp1);

    v0->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v0->simdInternal_);
    v1->simdInternal_ = _mm512_permutexvar_ps(
            _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), v1->simdInternal_);
}

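// Sum each 256-bit half of v0 and v1, add the four half-sums to m[0..3] (16-byte aligned)
// in the order {low v0, high v0, low v1, high v1}, and return the total.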
static inline float gmx_simdcall reduceIncr4ReturnSumHsimd(float* m, SimdFloat v0, SimdFloat v1)
{
    __m512 t0, t1;
    __m128 t2, t3;

    assert(std::size_t(m) % 16 == 0);

    t0 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0x88);
    t1 = _mm512_shuffle_f32x4(v0.simdInternal_, v1.simdInternal_, 0xDD);
    t0 = _mm512_add_ps(t0, t1);
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0x4E));
    t0 = _mm512_add_ps(t0, _mm512_permute_ps(t0, 0xB1));
    t0 = _mm512_maskz_compress_ps(avx512Int2Mask(0x1111), t0);

    t3 = _mm512_castps512_ps128(t0);
    t2 = _mm_load_ps(m);
    t2 = _mm_add_ps(t2, t3);
    _mm_store_ps(m, t2);

    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0x4E));
    t3 = _mm_add_ps(t3, _mm_permute_ps(t3, 0xB1));

    return _mm_cvtss_f32(t3);
}

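// Load four floats from unaligned f and duplicate each of them four times:
// {f[0],f[0],f[0],f[0], f[1],..., f[3],f[3],f[3],f[3]}.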
static inline SimdFloat gmx_simdcall loadUNDuplicate4(const float* f)
{
    return { _mm512_permute_ps(_mm512_maskz_expandloadu_ps(0x1111, f), 0) };
}

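// Load four floats from 16-byte-aligned f and repeat them across the whole register:
// {f[0..3], f[0..3], f[0..3], f[0..3]}.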
static inline SimdFloat gmx_simdcall load4DuplicateN(const float* f)
{
    return { _mm512_broadcast_f32x4(_mm_load_ps(f)) };
}

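// Load four consecutive floats from each of f, f + offset, f + 2*offset and f + 3*offset
// (unaligned) into the four 128-bit lanes.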
static inline SimdFloat gmx_simdcall loadU4NOffset(const float* f, int offset)
{
    const __m256i idx = _mm256_setr_epi32(0, 0, 1, 1, 2, 2, 3, 3);
    const __m256i gdx = _mm256_add_epi32(_mm256_setr_epi32(0, 2, 0, 2, 0, 2, 0, 2),
                                         _mm256_mullo_epi32(idx, _mm256_set1_epi32(offset)));
    return { _mm512_castpd_ps(_mm512_i32gather_pd(gdx, reinterpret_cast<const double*>(f), sizeof(float))) };
}

} // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_FLOAT_H