Apply clang-format-11
[alexxy/gromacs.git] / src / gromacs / simd / impl_x86_avx_512 / impl_x86_avx_512_util_double.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2014-2018, The GROMACS development team.
5  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36
37 #ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H
38 #define GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H
39
40 #include "config.h"
41
42 #include <cassert>
43 #include <cstdint>
44
45 #include <immintrin.h>
46
47 #include "gromacs/utility/basedefinitions.h"
48
49 #include "impl_x86_avx_512_general.h"
50 #include "impl_x86_avx_512_simd_double.h"
51
52 namespace gmx
53 {
54
55 static const int c_simdBestPairAlignmentDouble = 2;
56
57 namespace
58 {
59 // Multiply function optimized for powers of 2, for which it is done by
60 // shifting. Currently up to 8 is accelerated. Could be accelerated for any
61 // number with a constexpr log2 function.
62 template<int n>
63 static inline SimdDInt32 fastMultiply(SimdDInt32 x)
64 {
65     if (n == 2)
66     {
67         return _mm256_slli_epi32(x.simdInternal_, 1);
68     }
69     else if (n == 4)
70     {
71         return _mm256_slli_epi32(x.simdInternal_, 2);
72     }
73     else if (n == 8)
74     {
75         return _mm256_slli_epi32(x.simdInternal_, 3);
76     }
77     else
78     {
79         return x * n;
80     }
81 }
82
83 template<int align>
84 static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const double*, SimdDInt32)
85 {
86     // Nothing to do. Termination of recursion.
87 }
88
89 /* This is an internal helper function used by decr3Hsimd(...).
90  */
91 inline void gmx_simdcall decrHsimd(double* m, SimdDouble a)
92 {
93     __m256d t;
94
95     assert(std::size_t(m) % 32 == 0);
96
97     a.simdInternal_ = _mm512_add_pd(a.simdInternal_,
98                                     _mm512_shuffle_f64x2(a.simdInternal_, a.simdInternal_, 0xEE));
99     t               = _mm256_load_pd(m);
100     t               = _mm256_sub_pd(t, _mm512_castpd512_pd256(a.simdInternal_));
101     _mm256_store_pd(m, t);
102 }
103 } // namespace
104
105
106 template<int align, typename... Targs>
107 static inline void gmx_simdcall
108 gatherLoadBySimdIntTranspose(const double* base, SimdDInt32 offset, SimdDouble* v, Targs... Fargs)
109 {
110     if (align > 1)
111     {
112         offset = fastMultiply<align>(offset);
113     }
114     constexpr size_t scale = sizeof(double);
115     v->simdInternal_       = _mm512_i32gather_pd(offset.simdInternal_, base, scale);
116     gatherLoadBySimdIntTranspose<1>(base + 1, offset, Fargs...);
117 }
118
119 template<int align, typename... Targs>
120 static inline void gmx_simdcall
121 gatherLoadUBySimdIntTranspose(const double* base, SimdDInt32 offset, SimdDouble* v, Targs... Fargs)
122 {
123     gatherLoadBySimdIntTranspose<align>(base, offset, v, Fargs...);
124 }
125
126 template<int align, typename... Targs>
127 static inline void gmx_simdcall
128 gatherLoadTranspose(const double* base, const std::int32_t offset[], SimdDouble* v, Targs... Fargs)
129 {
130     gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), v, Fargs...);
131 }
132
133 template<int align, typename... Targs>
134 static inline void gmx_simdcall
135 gatherLoadUTranspose(const double* base, const std::int32_t offset[], SimdDouble* v, Targs... Fargs)
136 {
137     gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), v, Fargs...);
138 }
139
140 template<int align>
141 static inline void gmx_simdcall transposeScatterStoreU(double*            base,
142                                                        const std::int32_t offset[],
143                                                        SimdDouble         v0,
144                                                        SimdDouble         v1,
145                                                        SimdDouble         v2)
146 {
147     SimdDInt32 simdoffset = simdLoad(offset, SimdDInt32Tag());
148
149     if (align > 1)
150     {
151         simdoffset = fastMultiply<align>(simdoffset);
152         ;
153     }
154     constexpr size_t scale = sizeof(double);
155     _mm512_i32scatter_pd(base, simdoffset.simdInternal_, v0.simdInternal_, scale);
156     _mm512_i32scatter_pd(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
157     _mm512_i32scatter_pd(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
158 }
159
160 template<int align>
161 static inline void gmx_simdcall
162 transposeScatterIncrU(double* base, const std::int32_t offset[], SimdDouble v0, SimdDouble v1, SimdDouble v2)
163 {
164     __m512d                                  t[4], t5, t6, t7, t8;
165     alignas(GMX_SIMD_ALIGNMENT) std::int64_t o[8];
166     // TODO: should use fastMultiply
167     _mm512_store_epi64(o,
168                        _mm512_cvtepi32_epi64(_mm256_mullo_epi32(
169                                _mm256_load_si256((const __m256i*)(offset)), _mm256_set1_epi32(align))));
170     t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
171     t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
172     t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
173     t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
174     t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
175     t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
176     t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
177     t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
178     if (align < 4)
179     {
180         for (int i = 0; i < 4; i++)
181         {
182             _mm512_mask_storeu_pd(base + o[0 + i],
183                                   avx512Int2Mask(7),
184                                   _mm512_castpd256_pd512(_mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]),
185                                                                        _mm512_castpd512_pd256(t[i]))));
186             _mm512_mask_storeu_pd(
187                     base + o[4 + i],
188                     avx512Int2Mask(7),
189                     _mm512_castpd256_pd512(_mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]),
190                                                          _mm512_extractf64x4_pd(t[i], 1))));
191         }
192     }
193     else
194     {
195         if (align % 4 == 0)
196         {
197             for (int i = 0; i < 4; i++)
198             {
199                 _mm256_store_pd(
200                         base + o[0 + i],
201                         _mm256_add_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
202                 _mm256_store_pd(base + o[4 + i],
203                                 _mm256_add_pd(_mm256_load_pd(base + o[4 + i]),
204                                               _mm512_extractf64x4_pd(t[i], 1)));
205             }
206         }
207         else
208         {
209             for (int i = 0; i < 4; i++)
210             {
211                 _mm256_storeu_pd(
212                         base + o[0 + i],
213                         _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
214                 _mm256_storeu_pd(base + o[4 + i],
215                                  _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]),
216                                                _mm512_extractf64x4_pd(t[i], 1)));
217             }
218         }
219     }
220 }
221
222 template<int align>
223 static inline void gmx_simdcall
224 transposeScatterDecrU(double* base, const std::int32_t offset[], SimdDouble v0, SimdDouble v1, SimdDouble v2)
225 {
226     __m512d                                  t[4], t5, t6, t7, t8;
227     alignas(GMX_SIMD_ALIGNMENT) std::int64_t o[8];
228     // TODO: should use fastMultiply
229     _mm512_store_epi64(o,
230                        _mm512_cvtepi32_epi64(_mm256_mullo_epi32(
231                                _mm256_load_si256((const __m256i*)(offset)), _mm256_set1_epi32(align))));
232     t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
233     t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
234     t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
235     t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
236     t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
237     t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
238     t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
239     t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
240     if (align < 4)
241     {
242         for (int i = 0; i < 4; i++)
243         {
244             _mm512_mask_storeu_pd(base + o[0 + i],
245                                   avx512Int2Mask(7),
246                                   _mm512_castpd256_pd512(_mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]),
247                                                                        _mm512_castpd512_pd256(t[i]))));
248             _mm512_mask_storeu_pd(
249                     base + o[4 + i],
250                     avx512Int2Mask(7),
251                     _mm512_castpd256_pd512(_mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]),
252                                                          _mm512_extractf64x4_pd(t[i], 1))));
253         }
254     }
255     else
256     {
257         if (align % 4 == 0)
258         {
259             for (int i = 0; i < 4; i++)
260             {
261                 _mm256_store_pd(
262                         base + o[0 + i],
263                         _mm256_sub_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
264                 _mm256_store_pd(base + o[4 + i],
265                                 _mm256_sub_pd(_mm256_load_pd(base + o[4 + i]),
266                                               _mm512_extractf64x4_pd(t[i], 1)));
267             }
268         }
269         else
270         {
271             for (int i = 0; i < 4; i++)
272             {
273                 _mm256_storeu_pd(
274                         base + o[0 + i],
275                         _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
276                 _mm256_storeu_pd(base + o[4 + i],
277                                  _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]),
278                                                _mm512_extractf64x4_pd(t[i], 1)));
279             }
280         }
281     }
282 }
283
284 static inline void gmx_simdcall expandScalarsToTriplets(SimdDouble  scalar,
285                                                         SimdDouble* triplets0,
286                                                         SimdDouble* triplets1,
287                                                         SimdDouble* triplets2)
288 {
289     triplets0->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
290             _mm512_set_epi32(5, 4, 5, 4, 3, 2, 3, 2, 3, 2, 1, 0, 1, 0, 1, 0),
291             _mm512_castpd_si512(scalar.simdInternal_)));
292     triplets1->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
293             _mm512_set_epi32(11, 10, 9, 8, 9, 8, 9, 8, 7, 6, 7, 6, 7, 6, 5, 4),
294             _mm512_castpd_si512(scalar.simdInternal_)));
295     triplets2->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
296             _mm512_set_epi32(15, 14, 15, 14, 15, 14, 13, 12, 13, 12, 13, 12, 11, 10, 11, 10),
297             _mm512_castpd_si512(scalar.simdInternal_)));
298 }
299
300
301 static inline double gmx_simdcall
302 reduceIncr4ReturnSum(double* m, SimdDouble v0, SimdDouble v1, SimdDouble v2, SimdDouble v3)
303 {
304     __m512d t0, t2;
305     __m256d t3, t4;
306
307     assert(std::size_t(m) % 32 == 0);
308
309     t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permute_pd(v0.simdInternal_, 0x55));
310     t2 = _mm512_add_pd(v2.simdInternal_, _mm512_permute_pd(v2.simdInternal_, 0x55));
311     t0 = _mm512_mask_add_pd(
312             t0, avx512Int2Mask(0xAA), v1.simdInternal_, _mm512_permute_pd(v1.simdInternal_, 0x55));
313     t2 = _mm512_mask_add_pd(
314             t2, avx512Int2Mask(0xAA), v3.simdInternal_, _mm512_permute_pd(v3.simdInternal_, 0x55));
315     t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0x4E));
316     t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xF0), t2, _mm512_shuffle_f64x2(t2, t2, 0x4E));
317     t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0xB1));
318     t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0x0C), t0, t0, 0xEE);
319
320     t3 = _mm512_castpd512_pd256(t0);
321     t4 = _mm256_load_pd(m);
322     t4 = _mm256_add_pd(t4, t3);
323     _mm256_store_pd(m, t4);
324
325     t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
326     t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
327
328     return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
329 }
330
331 static inline SimdDouble gmx_simdcall loadDualHsimd(const double* m0, const double* m1)
332 {
333     assert(std::size_t(m0) % 32 == 0);
334     assert(std::size_t(m1) % 32 == 0);
335
336     return { _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_load_pd(m0)), _mm256_load_pd(m1), 1) };
337 }
338
339 static inline SimdDouble gmx_simdcall loadDuplicateHsimd(const double* m)
340 {
341     assert(std::size_t(m) % 32 == 0);
342
343     return { _mm512_broadcast_f64x4(_mm256_load_pd(m)) };
344 }
345
346 static inline SimdDouble gmx_simdcall loadU1DualHsimd(const double* m)
347 {
348     return { _mm512_insertf64x4(
349             _mm512_broadcastsd_pd(_mm_load_sd(m)), _mm256_broadcastsd_pd(_mm_load_sd(m + 1)), 1) };
350 }
351
352
353 static inline void gmx_simdcall storeDualHsimd(double* m0, double* m1, SimdDouble a)
354 {
355     assert(std::size_t(m0) % 32 == 0);
356     assert(std::size_t(m1) % 32 == 0);
357
358     _mm256_store_pd(m0, _mm512_castpd512_pd256(a.simdInternal_));
359     _mm256_store_pd(m1, _mm512_extractf64x4_pd(a.simdInternal_, 1));
360 }
361
362 static inline void gmx_simdcall incrDualHsimd(double* m0, double* m1, SimdDouble a)
363 {
364     assert(std::size_t(m0) % 32 == 0);
365     assert(std::size_t(m1) % 32 == 0);
366
367     __m256d x;
368
369     // Lower half
370     x = _mm256_load_pd(m0);
371     x = _mm256_add_pd(x, _mm512_castpd512_pd256(a.simdInternal_));
372     _mm256_store_pd(m0, x);
373
374     // Upper half
375     x = _mm256_load_pd(m1);
376     x = _mm256_add_pd(x, _mm512_extractf64x4_pd(a.simdInternal_, 1));
377     _mm256_store_pd(m1, x);
378 }
379
380 static inline void gmx_simdcall decr3Hsimd(double* m, SimdDouble a0, SimdDouble a1, SimdDouble a2)
381 {
382     decrHsimd(m, a0);
383     decrHsimd(m + GMX_SIMD_DOUBLE_WIDTH / 2, a1);
384     decrHsimd(m + GMX_SIMD_DOUBLE_WIDTH, a2);
385 }
386
387 template<int align>
388 static inline void gmx_simdcall gatherLoadTransposeHsimd(const double*      base0,
389                                                          const double*      base1,
390                                                          const std::int32_t offset[],
391                                                          SimdDouble*        v0,
392                                                          SimdDouble*        v1)
393 {
394     __m128i idx0, idx1;
395     __m256i idx;
396     __m512d tmp1, tmp2;
397
398     assert(std::size_t(offset) % 16 == 0);
399     assert(std::size_t(base0) % 16 == 0);
400     assert(std::size_t(base1) % 16 == 0);
401
402     idx0 = _mm_load_si128(reinterpret_cast<const __m128i*>(offset));
403
404     static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
405     idx0 = _mm_slli_epi32(idx0, align == 2 ? 1 : 2);
406
407     idx1 = _mm_add_epi32(idx0, _mm_set1_epi32(1));
408
409     idx = _mm256_inserti128_si256(_mm256_castsi128_si256(idx0), idx1, 1);
410
411     constexpr size_t scale = sizeof(double);
412     tmp1 = _mm512_i32gather_pd(idx, base0, scale); // TODO: Might be faster to use invidual loads
413     tmp2 = _mm512_i32gather_pd(idx, base1, scale);
414
415     v0->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0x44);
416     v1->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0xEE);
417 }
418
419 static inline double gmx_simdcall reduceIncr4ReturnSumHsimd(double* m, SimdDouble v0, SimdDouble v1)
420 {
421     __m512d t0;
422     __m256d t2, t3;
423
424     assert(std::size_t(m) % 32 == 0);
425
426     t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permutex_pd(v0.simdInternal_, 0x4E));
427     t0 = _mm512_mask_add_pd(
428             t0, avx512Int2Mask(0xCC), v1.simdInternal_, _mm512_permutex_pd(v1.simdInternal_, 0x4E));
429     t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
430     t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0xAA), t0, t0, 0xEE);
431
432     t2 = _mm512_castpd512_pd256(t0);
433     t3 = _mm256_load_pd(m);
434     t3 = _mm256_add_pd(t3, t2);
435     _mm256_store_pd(m, t3);
436
437     t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
438     t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
439
440     return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
441 }
442
443 static inline SimdDouble gmx_simdcall loadU4NOffset(const double* m, int offset)
444 {
445     return { _mm512_insertf64x4(
446             _mm512_castpd256_pd512(_mm256_loadu_pd(m)), _mm256_loadu_pd(m + offset), 1) };
447 }
448
449 } // namespace gmx
450
451 #endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H