Apply re-formatting to C++ in src/ tree.
[alexxy/gromacs.git] / src / gromacs / simd / impl_x86_mic / impl_x86_mic_util_float.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2014,2015,2016,2017,2018 by the GROMACS development team.
5  * Copyright (c) 2019,2020, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36
37 #ifndef GMX_SIMD_IMPL_X86_MIC_UTIL_FLOAT_H
38 #define GMX_SIMD_IMPL_X86_MIC_UTIL_FLOAT_H
39
40 #include "config.h"
41
42 #include <cassert>
43 #include <cstdint>
44
45 #include <immintrin.h>
46
47 #include "gromacs/utility/basedefinitions.h"
48
49 #include "impl_x86_mic_simd_float.h"
50
51 namespace gmx
52 {
53
54 namespace
55 {
56 /* This is an internal helper function used by decr3Hsimd(...).
57  */
58 inline void gmx_simdcall decrHsimd(float* m, SimdFloat a)
59 {
60     __m512 t;
61
62     assert(std::size_t(m) % 32 == 0);
63
64     t = _mm512_castpd_ps(_mm512_extload_pd(
65             reinterpret_cast<const double*>(m), _MM_UPCONV_PD_NONE, _MM_BROADCAST_4X8, _MM_HINT_NONE));
66     a = _mm512_add_ps(a.simdInternal_, _mm512_permute4f128_ps(a.simdInternal_, _MM_PERM_BADC));
67     t = _mm512_sub_ps(t, a.simdInternal_);
68     _mm512_mask_packstorelo_ps(m, _mm512_int2mask(0x00FF), t);
69 }
70 } // namespace
71
72 // On MIC it is better to use scatter operations, so we define the load routines
73 // that use a SIMD offset variable first.
74
75 template<int align>
76 static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const float* base,
77                                                              SimdFInt32   simdoffset,
78                                                              SimdFloat*   v0,
79                                                              SimdFloat*   v1,
80                                                              SimdFloat*   v2,
81                                                              SimdFloat*   v3)
82 {
83     assert(std::size_t(base) % 16 == 0);
84     assert(align % 4 == 0);
85
86     // All instructions might be latency ~4 on MIC, so we use shifts where we
87     // only need a single instruction (since the shift parameter is an immediate),
88     // but multiplication otherwise.
89     if (align == 4)
90     {
91         simdoffset.simdInternal_ = _mm512_slli_epi32(simdoffset.simdInternal_, 2);
92     }
93     else if (align == 8)
94     {
95         simdoffset.simdInternal_ = _mm512_slli_epi32(simdoffset.simdInternal_, 3);
96     }
97     else
98     {
99         simdoffset = simdoffset * SimdFInt32(align);
100     }
101
102     v0->simdInternal_ = _mm512_i32gather_ps(simdoffset.simdInternal_, base, sizeof(float));
103     v1->simdInternal_ = _mm512_i32gather_ps(simdoffset.simdInternal_, base + 1, sizeof(float));
104     v2->simdInternal_ = _mm512_i32gather_ps(simdoffset.simdInternal_, base + 2, sizeof(float));
105     v3->simdInternal_ = _mm512_i32gather_ps(simdoffset.simdInternal_, base + 3, sizeof(float));
106 }
107
108 template<int align>
109 static inline void gmx_simdcall
110                    gatherLoadUBySimdIntTranspose(const float* base, SimdFInt32 simdoffset, SimdFloat* v0, SimdFloat* v1)
111 {
112     // All instructions might be latency ~4 on MIC, so we use shifts where we
113     // only need a single instruction (since the shift parameter is an immediate),
114     // but multiplication otherwise.
115     // For align == 2 we can merge the constant into the scale parameter,
116     // which can take constants up to 8 in total.
117     if (align == 2)
118     {
119         v0->simdInternal_ = _mm512_i32gather_ps(simdoffset.simdInternal_, base, align * sizeof(float));
120         v1->simdInternal_ =
121                 _mm512_i32gather_ps(simdoffset.simdInternal_, base + 1, align * sizeof(float));
122     }
123     else
124     {
125         if (align == 4)
126         {
127             simdoffset.simdInternal_ = _mm512_slli_epi32(simdoffset.simdInternal_, 2);
128         }
129         else if (align == 8)
130         {
131             simdoffset.simdInternal_ = _mm512_slli_epi32(simdoffset.simdInternal_, 3);
132         }
133         else
134         {
135             simdoffset = simdoffset * SimdFInt32(align);
136         }
137         v0->simdInternal_ = _mm512_i32gather_ps(simdoffset.simdInternal_, base, sizeof(float));
138         v1->simdInternal_ = _mm512_i32gather_ps(simdoffset.simdInternal_, base + 1, sizeof(float));
139     }
140 }
141
142 template<int align>
143 static inline void gmx_simdcall
144                    gatherLoadBySimdIntTranspose(const float* base, SimdFInt32 simdoffset, SimdFloat* v0, SimdFloat* v1)
145 {
146     assert(std::size_t(base) % 8 == 0);
147     assert(align % 2 == 0);
148     gatherLoadUBySimdIntTranspose<align>(base, simdoffset, v0, v1);
149 }
150
151 template<int align>
152 static inline void gmx_simdcall gatherLoadTranspose(const float*       base,
153                                                     const std::int32_t offset[],
154                                                     SimdFloat*         v0,
155                                                     SimdFloat*         v1,
156                                                     SimdFloat*         v2,
157                                                     SimdFloat*         v3)
158 {
159     gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v0, v1, v2, v3);
160 }
161
162 template<int align>
163 static inline void gmx_simdcall
164                    gatherLoadTranspose(const float* base, const std::int32_t offset[], SimdFloat* v0, SimdFloat* v1)
165 {
166     gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdFInt32Tag()), v0, v1);
167 }
168
// Pair alignment (in floats) advertised as the best choice for this
// implementation; the two-output gather folds align == 2 into its scale.
static const int c_simdBestPairAlignmentFloat = 2;
170
171 template<int align>
172 static inline void gmx_simdcall gatherLoadUTranspose(const float*       base,
173                                                      const std::int32_t offset[],
174                                                      SimdFloat*         v0,
175                                                      SimdFloat*         v1,
176                                                      SimdFloat*         v2)
177 {
178     SimdFInt32 simdoffset;
179
180     assert(std::size_t(offset) % 64 == 0);
181
182     simdoffset = simdLoad(offset, SimdFInt32Tag());
183
184     // All instructions might be latency ~4 on MIC, so we use shifts where we
185     // only need a single instruction (since the shift parameter is an immediate),
186     // but multiplication otherwise.
187     if (align == 4)
188     {
189         simdoffset.simdInternal_ = _mm512_slli_epi32(simdoffset.simdInternal_, 2);
190     }
191     else if (align == 8)
192     {
193         simdoffset.simdInternal_ = _mm512_slli_epi32(simdoffset.simdInternal_, 3);
194     }
195     else
196     {
197         simdoffset = simdoffset * SimdFInt32(align);
198     }
199
200     v0->simdInternal_ = _mm512_i32gather_ps(simdoffset.simdInternal_, base, sizeof(float));
201     v1->simdInternal_ = _mm512_i32gather_ps(simdoffset.simdInternal_, base + 1, sizeof(float));
202     v2->simdInternal_ = _mm512_i32gather_ps(simdoffset.simdInternal_, base + 2, sizeof(float));
203 }
204
205
206 template<int align>
207 static inline void gmx_simdcall
208                    transposeScatterStoreU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
209 {
210     SimdFInt32 simdoffset;
211
212     assert(std::size_t(offset) % 64 == 0);
213
214     simdoffset = simdLoad(offset, SimdFInt32Tag());
215
216     // All instructions might be latency ~4 on MIC, so we use shifts where we
217     // only need a single instruction (since the shift parameter is an immediate),
218     // but multiplication otherwise.
219     if (align == 4)
220     {
221         simdoffset.simdInternal_ = _mm512_slli_epi32(simdoffset.simdInternal_, 2);
222     }
223     else if (align == 8)
224     {
225         simdoffset.simdInternal_ = _mm512_slli_epi32(simdoffset.simdInternal_, 3);
226     }
227     else
228     {
229         simdoffset = simdoffset * SimdFInt32(align);
230     }
231
232     _mm512_i32scatter_ps(base, simdoffset.simdInternal_, v0.simdInternal_, sizeof(float));
233     _mm512_i32scatter_ps(base + 1, simdoffset.simdInternal_, v1.simdInternal_, sizeof(float));
234     _mm512_i32scatter_ps(base + 2, simdoffset.simdInternal_, v2.simdInternal_, sizeof(float));
235 }
236
237
238 template<int align>
239 static inline void gmx_simdcall
240                    transposeScatterIncrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
241 {
242     alignas(GMX_SIMD_ALIGNMENT) float rdata0[GMX_SIMD_FLOAT_WIDTH];
243     alignas(GMX_SIMD_ALIGNMENT) float rdata1[GMX_SIMD_FLOAT_WIDTH];
244     alignas(GMX_SIMD_ALIGNMENT) float rdata2[GMX_SIMD_FLOAT_WIDTH];
245
246     store(rdata0, v0);
247     store(rdata1, v1);
248     store(rdata2, v2);
249
250     for (int i = 0; i < GMX_SIMD_FLOAT_WIDTH; i++)
251     {
252         base[align * offset[i] + 0] += rdata0[i];
253         base[align * offset[i] + 1] += rdata1[i];
254         base[align * offset[i] + 2] += rdata2[i];
255     }
256 }
257
258 template<int align>
259 static inline void gmx_simdcall
260                    transposeScatterDecrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
261 {
262     alignas(GMX_SIMD_ALIGNMENT) float rdata0[GMX_SIMD_FLOAT_WIDTH];
263     alignas(GMX_SIMD_ALIGNMENT) float rdata1[GMX_SIMD_FLOAT_WIDTH];
264     alignas(GMX_SIMD_ALIGNMENT) float rdata2[GMX_SIMD_FLOAT_WIDTH];
265
266     store(rdata0, v0);
267     store(rdata1, v1);
268     store(rdata2, v2);
269
270     for (int i = 0; i < GMX_SIMD_FLOAT_WIDTH; i++)
271     {
272         base[align * offset[i] + 0] -= rdata0[i];
273         base[align * offset[i] + 1] -= rdata1[i];
274         base[align * offset[i] + 2] -= rdata2[i];
275     }
276 }
277
278 static inline void gmx_simdcall expandScalarsToTriplets(SimdFloat  scalar,
279                                                         SimdFloat* triplets0,
280                                                         SimdFloat* triplets1,
281                                                         SimdFloat* triplets2)
282 {
283     triplets0->simdInternal_ = _mm512_castsi512_ps(
284             _mm512_permutevar_epi32(_mm512_set_epi32(5, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1, 0, 0, 0),
285                                     _mm512_castps_si512(scalar.simdInternal_)));
286     triplets1->simdInternal_ = _mm512_castsi512_ps(_mm512_permutevar_epi32(
287             _mm512_set_epi32(10, 10, 9, 9, 9, 8, 8, 8, 7, 7, 7, 6, 6, 6, 5, 5),
288             _mm512_castps_si512(scalar.simdInternal_)));
289     triplets2->simdInternal_ = _mm512_castsi512_ps(_mm512_permutevar_epi32(
290             _mm512_set_epi32(15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12, 12, 11, 11, 11, 10),
291             _mm512_castps_si512(scalar.simdInternal_)));
292 }
293
294
295 static inline float gmx_simdcall reduceIncr4ReturnSum(float* m, SimdFloat v0, SimdFloat v1, SimdFloat v2, SimdFloat v3)
296 {
297     float  f;
298     __m512 t0, t1, t2, t3;
299
300     assert(std::size_t(m) % 16 == 0);
301
302     t0 = _mm512_add_ps(v0.simdInternal_, _mm512_swizzle_ps(v0.simdInternal_, _MM_SWIZ_REG_BADC));
303     t0 = _mm512_mask_add_ps(t0,
304                             _mm512_int2mask(0xCCCC),
305                             v2.simdInternal_,
306                             _mm512_swizzle_ps(v2.simdInternal_, _MM_SWIZ_REG_BADC));
307     t1 = _mm512_add_ps(v1.simdInternal_, _mm512_swizzle_ps(v1.simdInternal_, _MM_SWIZ_REG_BADC));
308     t1 = _mm512_mask_add_ps(t1,
309                             _mm512_int2mask(0xCCCC),
310                             v3.simdInternal_,
311                             _mm512_swizzle_ps(v3.simdInternal_, _MM_SWIZ_REG_BADC));
312     t2 = _mm512_add_ps(t0, _mm512_swizzle_ps(t0, _MM_SWIZ_REG_CDAB));
313     t2 = _mm512_mask_add_ps(t2, _mm512_int2mask(0xAAAA), t1, _mm512_swizzle_ps(t1, _MM_SWIZ_REG_CDAB));
314
315     t2 = _mm512_add_ps(t2, _mm512_permute4f128_ps(t2, _MM_PERM_BADC));
316     t2 = _mm512_add_ps(t2, _mm512_permute4f128_ps(t2, _MM_PERM_CDAB));
317
318     t0 = _mm512_mask_extload_ps(
319             _mm512_undefined_ps(), _mm512_int2mask(0xF), m, _MM_UPCONV_PS_NONE, _MM_BROADCAST_4X16, _MM_HINT_NONE);
320     t0 = _mm512_add_ps(t0, t2);
321     _mm512_mask_packstorelo_ps(m, _mm512_int2mask(0xF), t0);
322
323     t2 = _mm512_add_ps(t2, _mm512_swizzle_ps(t2, _MM_SWIZ_REG_BADC));
324     t2 = _mm512_add_ps(t2, _mm512_swizzle_ps(t2, _MM_SWIZ_REG_CDAB));
325
326     _mm512_mask_packstorelo_ps(&f, _mm512_mask2int(0x1), t2);
327     return f;
328 }
329
330 static inline SimdFloat gmx_simdcall loadDualHsimd(const float* m0, const float* m1)
331 {
332     assert(std::size_t(m0) % 32 == 0);
333     assert(std::size_t(m1) % 32 == 0);
334
335     return _mm512_castpd_ps(_mm512_mask_extload_pd(
336             _mm512_extload_pd(reinterpret_cast<const double*>(m0), _MM_UPCONV_PD_NONE, _MM_BROADCAST_4X8, _MM_HINT_NONE),
337             _mm512_int2mask(0xF0),
338             reinterpret_cast<const double*>(m1),
339             _MM_UPCONV_PD_NONE,
340             _MM_BROADCAST_4X8,
341             _MM_HINT_NONE));
342 }
343
344 static inline SimdFloat gmx_simdcall loadDuplicateHsimd(const float* m)
345 {
346     assert(std::size_t(m) % 32 == 0);
347
348     return _mm512_castpd_ps(_mm512_extload_pd(
349             reinterpret_cast<const double*>(m), _MM_UPCONV_PD_NONE, _MM_BROADCAST_4X8, _MM_HINT_NONE));
350 }
351
352 static inline SimdFloat gmx_simdcall loadU1DualHsimd(const float* m)
353 {
354     return _mm512_mask_extload_ps(_mm512_extload_ps(m, _MM_UPCONV_PS_NONE, _MM_BROADCAST_1X16, _MM_HINT_NONE),
355                                   _mm512_int2mask(0xFF00),
356                                   m + 1,
357                                   _MM_UPCONV_PS_NONE,
358                                   _MM_BROADCAST_1X16,
359                                   _MM_HINT_NONE);
360 }
361
362
363 static inline void gmx_simdcall storeDualHsimd(float* m0, float* m1, SimdFloat a)
364 {
365     __m512 t0;
366
367     assert(std::size_t(m0) % 32 == 0);
368     assert(std::size_t(m1) % 32 == 0);
369
370     _mm512_mask_packstorelo_ps(m0, _mm512_int2mask(0x00FF), a.simdInternal_);
371     _mm512_mask_packstorelo_ps(m1, _mm512_int2mask(0xFF00), a.simdInternal_);
372 }
373
374 static inline void gmx_simdcall incrDualHsimd(float* m0, float* m1, SimdFloat a)
375 {
376     assert(std::size_t(m0) % 32 == 0);
377     assert(std::size_t(m1) % 32 == 0);
378
379     __m512 x;
380
381     // Update lower half
382     x = _mm512_castpd_ps(_mm512_extload_pd(
383             reinterpret_cast<const double*>(m0), _MM_UPCONV_PD_NONE, _MM_BROADCAST_4X8, _MM_HINT_NONE));
384     x = _mm512_add_ps(x, a.simdInternal_);
385     _mm512_mask_packstorelo_ps(m0, _mm512_int2mask(0x00FF), x);
386
387     // Update upper half
388     x = _mm512_castpd_ps(_mm512_extload_pd(
389             reinterpret_cast<const double*>(m1), _MM_UPCONV_PD_NONE, _MM_BROADCAST_4X8, _MM_HINT_NONE));
390     x = _mm512_add_ps(x, a.simdInternal_);
391     _mm512_mask_packstorelo_ps(m1, _mm512_int2mask(0xFF00), x);
392 }
393
394 static inline void gmx_simdcall decr3Hsimd(float* m, SimdFloat a0, SimdFloat a1, SimdFloat a2)
395 {
396     assert(std::size_t(m) % 32 == 0);
397     decrHsimd(m, a0);
398     decrHsimd(m + GMX_SIMD_FLOAT_WIDTH / 2, a1);
399     decrHsimd(m + GMX_SIMD_FLOAT_WIDTH, a2);
400 }
401
402
403 template<int align>
404 static inline void gmx_simdcall gatherLoadTransposeHsimd(const float*       base0,
405                                                          const float*       base1,
406                                                          const std::int32_t offset[],
407                                                          SimdFloat*         v0,
408                                                          SimdFloat*         v1)
409 {
410     __m512i idx0, idx1, idx;
411     __m512  tmp1, tmp2;
412
413     assert(std::size_t(offset) % 32 == 0);
414     assert(std::size_t(base0) % 8 == 0);
415     assert(std::size_t(base1) % 8 == 0);
416     assert(std::size_t(align) % 2 == 0);
417
418     idx0 = _mm512_loadunpacklo_epi32(_mm512_undefined_epi32(), offset);
419
420     idx0 = _mm512_mullo_epi32(idx0, _mm512_set1_epi32(align));
421     idx1 = _mm512_add_epi32(idx0, _mm512_set1_epi32(1));
422
423     idx = _mm512_mask_permute4f128_epi32(idx0, _mm512_int2mask(0xFF00), idx1, _MM_PERM_BABA);
424
425     tmp1 = _mm512_i32gather_ps(idx, base0, sizeof(float));
426     tmp2 = _mm512_i32gather_ps(idx, base1, sizeof(float));
427
428     v0->simdInternal_ = _mm512_mask_permute4f128_ps(tmp1, _mm512_int2mask(0xFF00), tmp2, _MM_PERM_BABA);
429     v1->simdInternal_ = _mm512_mask_permute4f128_ps(tmp2, _mm512_int2mask(0x00FF), tmp1, _MM_PERM_DCDC);
430 }
431
432 static inline float gmx_simdcall reduceIncr4ReturnSumHsimd(float* m, SimdFloat v0, SimdFloat v1)
433 {
434     float  f;
435     __m512 t0, t1;
436
437     assert(std::size_t(m) % 32 == 0);
438
439     t0 = _mm512_add_ps(v0.simdInternal_, _mm512_swizzle_ps(v0.simdInternal_, _MM_SWIZ_REG_BADC));
440     t0 = _mm512_mask_add_ps(t0,
441                             _mm512_int2mask(0xCCCC),
442                             v1.simdInternal_,
443                             _mm512_swizzle_ps(v1.simdInternal_, _MM_SWIZ_REG_BADC));
444     t0 = _mm512_add_ps(t0, _mm512_swizzle_ps(t0, _MM_SWIZ_REG_CDAB));
445     t0 = _mm512_add_ps(t0, _mm512_castpd_ps(_mm512_swizzle_pd(_mm512_castps_pd(t0), _MM_SWIZ_REG_BADC)));
446     t0 = _mm512_mask_permute4f128_ps(t0, _mm512_int2mask(0xAAAA), t0, _MM_PERM_BADC);
447     t1 = _mm512_mask_extload_ps(
448             _mm512_undefined_ps(), _mm512_int2mask(0xF), m, _MM_UPCONV_PS_NONE, _MM_BROADCAST_4X16, _MM_HINT_NONE);
449     t1 = _mm512_add_ps(t1, t0);
450     _mm512_mask_packstorelo_ps(m, _mm512_int2mask(0xF), t1);
451
452     t0 = _mm512_add_ps(t0, _mm512_swizzle_ps(t0, _MM_SWIZ_REG_BADC));
453     t0 = _mm512_add_ps(t0, _mm512_swizzle_ps(t0, _MM_SWIZ_REG_CDAB));
454
455     _mm512_mask_packstorelo_ps(&f, _mm512_mask2int(0x1), t0);
456     return f;
457 }
458
459 } // namespace gmx
460
461 #endif // GMX_SIMD_IMPL_X86_MIC_UTIL_FLOAT_H