Apply clang-format to source tree
[alexxy/gromacs.git] / src / gromacs / simd / impl_x86_avx_512 / impl_x86_avx_512_util_double.h
/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014-2018, The GROMACS development team.
 * Copyright (c) 2019, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H

#include "config.h"

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_double.h"

namespace gmx
{

static const int c_simdBestPairAlignmentDouble = 2;

namespace
{
// Multiply function optimized for powers of 2, which are handled with a
// shift. Currently factors up to 8 are accelerated; any power of two could
// be accelerated with a constexpr log2 function.
template<int n>
SimdDInt32 fastMultiply(SimdDInt32 x)
{
    if (n == 2)
    {
        return _mm256_slli_epi32(x.simdInternal_, 1);
    }
    else if (n == 4)
    {
        return _mm256_slli_epi32(x.simdInternal_, 2);
    }
    else if (n == 8)
    {
        return _mm256_slli_epi32(x.simdInternal_, 3);
    }
    else
    {
        return x * n;
    }
}
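
// Usage sketch for fastMultiply() (illustrative only; "idx" is a hypothetical
// variable, not part of this header):
//     SimdDInt32 idx = simdLoad(offset, SimdDInt32Tag());
//     idx            = fastMultiply<4>(idx); // lane-wise idx*4, compiled to a shift by 2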

template<int align>
static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const double*, SimdDInt32)
{
    // Nothing to do. Termination of recursion.
}
} // namespace


template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadBySimdIntTranspose(const double* base, SimdDInt32 offset, SimdDouble* v, Targs... Fargs)
{
    if (align > 1)
    {
        offset = fastMultiply<align>(offset);
    }
    constexpr size_t scale = sizeof(double);
    v->simdInternal_       = _mm512_i32gather_pd(offset.simdInternal_, base, scale);
    gatherLoadBySimdIntTranspose<1>(base + 1, offset, Fargs...);
}
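
// The variadic overload above peels off one output vector per call: output k
// (k = 0, 1, ...) gets lane i from base[align*offset[i] + k], and the
// recursive call passes base + 1 with align 1 because the offsets are already
// scaled. A minimal usage sketch (illustrative; names are hypothetical):
//     SimdDouble x, y, z;
//     gatherLoadBySimdIntTranspose<4>(xyzwBuffer, indices, &x, &y, &z);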

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadUBySimdIntTranspose(const double* base, SimdDInt32 offset, SimdDouble* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, offset, v, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadTranspose(const double* base, const std::int32_t offset[], SimdDouble* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), v, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadUTranspose(const double* base, const std::int32_t offset[], SimdDouble* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), v, Fargs...);
}

template<int align>
static inline void gmx_simdcall transposeScatterStoreU(double*            base,
                                                       const std::int32_t offset[],
                                                       SimdDouble         v0,
                                                       SimdDouble         v1,
                                                       SimdDouble         v2)
{
    SimdDInt32 simdoffset = simdLoad(offset, SimdDInt32Tag());

    if (align > 1)
    {
        simdoffset = fastMultiply<align>(simdoffset);
    }
    constexpr size_t scale = sizeof(double);
    _mm512_i32scatter_pd(base, simdoffset.simdInternal_, v0.simdInternal_, scale);
    _mm512_i32scatter_pd(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
    _mm512_i32scatter_pd(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
}
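
// transposeScatterStoreU() writes the triplet { v0[i], v1[i], v2[i] } to
// base[align*offset[i] + 0..2] for each of the 8 SIMD lanes, e.g. scattering
// x/y/z values back to an xyz-interleaved array (illustrative description).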

template<int align>
static inline void gmx_simdcall
                   transposeScatterIncrU(double* base, const std::int32_t offset[], SimdDouble v0, SimdDouble v1, SimdDouble v2)
{
    __m512d                                  t[4], t5, t6, t7, t8;
    alignas(GMX_SIMD_ALIGNMENT) std::int64_t o[8];
    // TODO: should use fastMultiply
    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(
                                  _mm256_load_si256((const __m256i*)(offset)), _mm256_set1_epi32(align))));
    t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
    t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
    t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
    t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
    t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
    t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
    t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
    t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
    if (align < 4)
    {
        for (int i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7),
                                  _mm512_castpd256_pd512(_mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]),
                                                                       _mm512_castpd512_pd256(t[i]))));
            _mm512_mask_storeu_pd(
                    base + o[4 + i], avx512Int2Mask(7),
                    _mm512_castpd256_pd512(_mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]),
                                                         _mm512_extractf64x4_pd(t[i], 1))));
        }
    }
    else
    {
        if (align % 4 == 0)
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_store_pd(base + o[0 + i], _mm256_add_pd(_mm256_load_pd(base + o[0 + i]),
                                                               _mm512_castpd512_pd256(t[i])));
                _mm256_store_pd(base + o[4 + i], _mm256_add_pd(_mm256_load_pd(base + o[4 + i]),
                                                               _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
        else
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_storeu_pd(base + o[0 + i], _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]),
                                                                _mm512_castpd512_pd256(t[i])));
                _mm256_storeu_pd(base + o[4 + i], _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]),
                                                                _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
    }
}
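
// transposeScatterIncrU() above (and transposeScatterDecrU() below) first
// transpose the three inputs into eight xyz triplets (each t[k] holds two
// triplets, with a zero padding the fourth slot) and then add or subtract
// each triplet to base[align*offset[i] + 0..2]. For align >= 4 the padded
// fourth double of each slot is loaded and stored back with zero added,
// which is what allows the full 4-wide 256-bit loads/stores in that branch.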

template<int align>
static inline void gmx_simdcall
                   transposeScatterDecrU(double* base, const std::int32_t offset[], SimdDouble v0, SimdDouble v1, SimdDouble v2)
{
    __m512d                                  t[4], t5, t6, t7, t8;
    alignas(GMX_SIMD_ALIGNMENT) std::int64_t o[8];
    // TODO: should use fastMultiply
    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(
                                  _mm256_load_si256((const __m256i*)(offset)), _mm256_set1_epi32(align))));
    t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
    t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
    t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
    t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
    t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
    t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
    t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
    t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
    if (align < 4)
    {
        for (int i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7),
                                  _mm512_castpd256_pd512(_mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]),
                                                                       _mm512_castpd512_pd256(t[i]))));
            _mm512_mask_storeu_pd(
                    base + o[4 + i], avx512Int2Mask(7),
                    _mm512_castpd256_pd512(_mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]),
                                                         _mm512_extractf64x4_pd(t[i], 1))));
        }
    }
    else
    {
        if (align % 4 == 0)
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_store_pd(base + o[0 + i], _mm256_sub_pd(_mm256_load_pd(base + o[0 + i]),
                                                               _mm512_castpd512_pd256(t[i])));
                _mm256_store_pd(base + o[4 + i], _mm256_sub_pd(_mm256_load_pd(base + o[4 + i]),
                                                               _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
        else
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_storeu_pd(base + o[0 + i], _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]),
                                                                _mm512_castpd512_pd256(t[i])));
                _mm256_storeu_pd(base + o[4 + i], _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]),
                                                                _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
    }
}

static inline void gmx_simdcall expandScalarsToTriplets(SimdDouble  scalar,
                                                        SimdDouble* triplets0,
                                                        SimdDouble* triplets1,
                                                        SimdDouble* triplets2)
{
    triplets0->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
            _mm512_set_epi32(5, 4, 5, 4, 3, 2, 3, 2, 3, 2, 1, 0, 1, 0, 1, 0),
            _mm512_castpd_si512(scalar.simdInternal_)));
    triplets1->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
            _mm512_set_epi32(11, 10, 9, 8, 9, 8, 9, 8, 7, 6, 7, 6, 7, 6, 5, 4),
            _mm512_castpd_si512(scalar.simdInternal_)));
    triplets2->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
            _mm512_set_epi32(15, 14, 15, 14, 15, 14, 13, 12, 13, 12, 13, 12, 11, 10, 11, 10),
            _mm512_castpd_si512(scalar.simdInternal_)));
}
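
// With scalar = { s0, s1, ..., s7 }, the outputs are (element by element)
//     triplets0 = { s0, s0, s0, s1, s1, s1, s2, s2 }
//     triplets1 = { s2, s3, s3, s3, s4, s4, s4, s5 }
//     triplets2 = { s5, s5, s6, s6, s6, s7, s7, s7 }
// i.e. each scalar is repeated three times across the concatenated outputs;
// the 32-bit permutes above move every double as a pair of epi32 lanes.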


static inline double gmx_simdcall
                     reduceIncr4ReturnSum(double* m, SimdDouble v0, SimdDouble v1, SimdDouble v2, SimdDouble v3)
{
    __m512d t0, t2;
    __m256d t3, t4;

    assert(std::size_t(m) % 32 == 0);

    t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permute_pd(v0.simdInternal_, 0x55));
    t2 = _mm512_add_pd(v2.simdInternal_, _mm512_permute_pd(v2.simdInternal_, 0x55));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xAA), v1.simdInternal_,
                            _mm512_permute_pd(v1.simdInternal_, 0x55));
    t2 = _mm512_mask_add_pd(t2, avx512Int2Mask(0xAA), v3.simdInternal_,
                            _mm512_permute_pd(v3.simdInternal_, 0x55));
    t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0x4E));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xF0), t2, _mm512_shuffle_f64x2(t2, t2, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0xB1));
    t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0x0C), t0, t0, 0xEE);

    t3 = _mm512_castpd512_pd256(t0);
    t4 = _mm256_load_pd(m);
    t4 = _mm256_add_pd(t4, t3);
    _mm256_store_pd(m, t4);

    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));

    return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
}
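
// reduceIncr4ReturnSum() adds the full-vector sums of v0..v3 to m[0]..m[3]
// and returns the total of the four sums. A scalar sketch of the contract
// (illustrative only; the shuffles above do this entirely in registers):
//     for (int k = 0; k < 4; k++) { m[k] += sum_i(vk[i]); }
//     return sum over all elements of v0, v1, v2 and v3;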

static inline SimdDouble gmx_simdcall loadDualHsimd(const double* m0, const double* m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return { _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_load_pd(m0)), _mm256_load_pd(m1), 1) };
}
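
// The half-SIMD ("Hsimd") loads here treat the 512-bit register as two
// independent 256-bit halves: loadDualHsimd() returns { m0[0..3], m1[0..3] },
// loadDuplicateHsimd() returns { m[0..3], m[0..3] }, and loadU1DualHsimd()
// broadcasts m[0] into the lower half and m[1] into the upper half.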

static inline SimdDouble gmx_simdcall loadDuplicateHsimd(const double* m)
{
    assert(std::size_t(m) % 32 == 0);

    return { _mm512_broadcast_f64x4(_mm256_load_pd(m)) };
}

static inline SimdDouble gmx_simdcall loadU1DualHsimd(const double* m)
{
    return { _mm512_insertf64x4(_mm512_broadcastsd_pd(_mm_load_sd(m)),
                                _mm256_broadcastsd_pd(_mm_load_sd(m + 1)), 1) };
}


static inline void gmx_simdcall storeDualHsimd(double* m0, double* m1, SimdDouble a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_pd(m0, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m1, _mm512_extractf64x4_pd(a.simdInternal_, 1));
}

static inline void gmx_simdcall incrDualHsimd(double* m0, double* m1, SimdDouble a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256d x;

    // Lower half
    x = _mm256_load_pd(m0);
    x = _mm256_add_pd(x, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m0, x);

    // Upper half
    x = _mm256_load_pd(m1);
    x = _mm256_add_pd(x, _mm512_extractf64x4_pd(a.simdInternal_, 1));
    _mm256_store_pd(m1, x);
}

static inline void gmx_simdcall decrHsimd(double* m, SimdDouble a)
{
    __m256d t;

    assert(std::size_t(m) % 32 == 0);

    a.simdInternal_ = _mm512_add_pd(a.simdInternal_,
                                    _mm512_shuffle_f64x2(a.simdInternal_, a.simdInternal_, 0xEE));
    t               = _mm256_load_pd(m);
    t               = _mm256_sub_pd(t, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m, t);
}
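
// decrHsimd() subtracts the sum of the two 256-bit halves of a from memory,
// i.e. m[i] -= a[i] + a[4 + i] for i = 0..3, while incrDualHsimd() above adds
// the lower half of a to m0[0..3] and the upper half to m1[0..3].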


template<int align>
static inline void gmx_simdcall gatherLoadTransposeHsimd(const double*      base0,
                                                         const double*      base1,
                                                         const std::int32_t offset[],
                                                         SimdDouble*        v0,
                                                         SimdDouble*        v1)
{
    __m128i idx0, idx1;
    __m256i idx;
    __m512d tmp1, tmp2;

    assert(std::size_t(offset) % 16 == 0);
    assert(std::size_t(base0) % 16 == 0);
    assert(std::size_t(base1) % 16 == 0);

    idx0 = _mm_load_si128(reinterpret_cast<const __m128i*>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    idx0 = _mm_slli_epi32(idx0, align == 2 ? 1 : 2);

    idx1 = _mm_add_epi32(idx0, _mm_set1_epi32(1));

    idx = _mm256_inserti128_si256(_mm256_castsi128_si256(idx0), idx1, 1);

    constexpr size_t scale = sizeof(double);
    tmp1 = _mm512_i32gather_pd(idx, base0, scale); // TODO: Might be faster to use individual loads
    tmp2 = _mm512_i32gather_pd(idx, base1, scale);

    v0->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0x44);
    v1->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0xEE);
}
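
// On return, v0 holds element 0 of each pair, { base0[align*offset[0..3]],
// base1[align*offset[0..3]] }, and v1 holds element 1 from the same addresses
// plus one; base0 data sits in the lower 256-bit half and base1 in the upper.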

static inline double gmx_simdcall reduceIncr4ReturnSumHsimd(double* m, SimdDouble v0, SimdDouble v1)
{
    __m512d t0;
    __m256d t2, t3;

    assert(std::size_t(m) % 32 == 0);

    t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permutex_pd(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xCC), v1.simdInternal_,
                            _mm512_permutex_pd(v1.simdInternal_, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
    t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0xAA), t0, t0, 0xEE);

    t2 = _mm512_castpd512_pd256(t0);
    t3 = _mm256_load_pd(m);
    t3 = _mm256_add_pd(t3, t2);
    _mm256_store_pd(m, t3);

    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));

    return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
}
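
// reduceIncr4ReturnSumHsimd() treats v0 and v1 as two half-SIMD vectors each:
// m[0] += sum of the lower half of v0, m[1] += sum of the upper half of v0,
// m[2] and m[3] likewise for v1, and the total of the four increments is
// returned.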

static inline SimdDouble gmx_simdcall loadU4NOffset(const double* m, int offset)
{
    return { _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_loadu_pd(m)),
                                _mm256_loadu_pd(m + offset), 1) };
}

} // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H