/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014-2018, The GROMACS development team.
 * Copyright (c) 2019,2020, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H

#include "config.h"

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_double.h"

namespace gmx
{

static const int c_simdBestPairAlignmentDouble = 2;

namespace
{
// Multiply function optimized for powers of 2, which are handled with a shift.
// Currently up to 8 is accelerated. Any power of 2 could be accelerated with a
// constexpr log2 function.
template<int n>
static inline SimdDInt32 fastMultiply(SimdDInt32 x)
{
    if (n == 2)
    {
        return _mm256_slli_epi32(x.simdInternal_, 1);
    }
    else if (n == 4)
    {
        return _mm256_slli_epi32(x.simdInternal_, 2);
    }
    else if (n == 8)
    {
        return _mm256_slli_epi32(x.simdInternal_, 3);
    }
    else
    {
        return x * n;
    }
}
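
// Illustration (for reference only): with align == 4, fastMultiply<4>(offset)
// turns element indices {0, 1, 2, ...} into {0, 4, 8, ...} with a single shift,
// giving the same result as offset * 4.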

template<int align>
static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const double*, SimdDInt32)
{
    // Nothing to do. Termination of recursion.
}

/* This is an internal helper function used by decr3Hsimd(...).
 * It subtracts the sum of the two 4-double halves of a from the four
 * doubles at m, i.e. m[0..3] -= a[0..3] + a[4..7].
 */
inline void gmx_simdcall decrHsimd(double* m, SimdDouble a)
{
    __m256d t;

    assert(std::size_t(m) % 32 == 0);

    a.simdInternal_ = _mm512_add_pd(a.simdInternal_,
                                    _mm512_shuffle_f64x2(a.simdInternal_, a.simdInternal_, 0xEE));
    t               = _mm256_load_pd(m);
    t               = _mm256_sub_pd(t, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m, t);
}
} // namespace


template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadBySimdIntTranspose(const double* base, SimdDInt32 offset, SimdDouble* v, Targs... Fargs)
{
    if (align > 1)
    {
        offset = fastMultiply<align>(offset);
    }
    constexpr size_t scale = sizeof(double);
    v->simdInternal_       = _mm512_i32gather_pd(offset.simdInternal_, base, scale);
    gatherLoadBySimdIntTranspose<1>(base + 1, offset, Fargs...);
}
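
// For reference, with two output vectors the recursion above produces
//   v0[i] = base[align*offset[i]]    and    v1[i] = base[align*offset[i] + 1]
// for each SIMD element i; further output vectors load base[align*offset[i] + 2],
// + 3, and so on.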

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadUBySimdIntTranspose(const double* base, SimdDInt32 offset, SimdDouble* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, offset, v, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadTranspose(const double* base, const std::int32_t offset[], SimdDouble* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), v, Fargs...);
}

template<int align, typename... Targs>
static inline void gmx_simdcall
                   gatherLoadUTranspose(const double* base, const std::int32_t offset[], SimdDouble* v, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), v, Fargs...);
}
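
// Note: the unaligned (U) variants above simply forward to the gather-based
// loads, since AVX-512 gather instructions carry no alignment requirement.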

template<int align>
static inline void gmx_simdcall transposeScatterStoreU(double*            base,
                                                       const std::int32_t offset[],
                                                       SimdDouble         v0,
                                                       SimdDouble         v1,
                                                       SimdDouble         v2)
{
    SimdDInt32 simdoffset = simdLoad(offset, SimdDInt32Tag());

    if (align > 1)
    {
        simdoffset = fastMultiply<align>(simdoffset);
    }
    constexpr size_t scale = sizeof(double);
    _mm512_i32scatter_pd(base, simdoffset.simdInternal_, v0.simdInternal_, scale);
    _mm512_i32scatter_pd(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
    _mm512_i32scatter_pd(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
}
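
// For reference, transposeScatterStoreU writes, for each SIMD element i,
//   base[align*offset[i] + 0] = v0[i]
//   base[align*offset[i] + 1] = v1[i]
//   base[align*offset[i] + 2] = v2[i]
// using unaligned scatter stores.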

template<int align>
static inline void gmx_simdcall
                   transposeScatterIncrU(double* base, const std::int32_t offset[], SimdDouble v0, SimdDouble v1, SimdDouble v2)
{
    __m512d                                  t[4], t5, t6, t7, t8;
    alignas(GMX_SIMD_ALIGNMENT) std::int64_t o[8];
    // TODO: should use fastMultiply
    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(
                                  _mm256_load_si256(reinterpret_cast<const __m256i*>(offset)),
                                  _mm256_set1_epi32(align))));
    t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
    t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
    t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
    t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
    t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
    t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
    t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
    t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
    if (align < 4)
    {
        for (int i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7),
                                  _mm512_castpd256_pd512(_mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]),
                                                                       _mm512_castpd512_pd256(t[i]))));
            _mm512_mask_storeu_pd(
                    base + o[4 + i], avx512Int2Mask(7),
                    _mm512_castpd256_pd512(_mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]),
                                                         _mm512_extractf64x4_pd(t[i], 1))));
        }
    }
    else
    {
        if (align % 4 == 0)
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_store_pd(base + o[0 + i], _mm256_add_pd(_mm256_load_pd(base + o[0 + i]),
                                                               _mm512_castpd512_pd256(t[i])));
                _mm256_store_pd(base + o[4 + i], _mm256_add_pd(_mm256_load_pd(base + o[4 + i]),
                                                               _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
        else
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_storeu_pd(base + o[0 + i], _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]),
                                                                _mm512_castpd512_pd256(t[i])));
                _mm256_storeu_pd(base + o[4 + i], _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]),
                                                                _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
    }
}
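
// For reference, transposeScatterIncrU performs, for each SIMD element i,
//   base[align*offset[i] + 0..2] += { v0[i], v1[i], v2[i] }.
// The registers t[0..3] are first transposed so that t[j] holds the triplet for
// element j in its low 256 bits and for element j+4 in its high 256 bits. For
// align < 4 the masked stores touch only three doubles per row, so neighboring
// data is never overwritten.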

template<int align>
static inline void gmx_simdcall
                   transposeScatterDecrU(double* base, const std::int32_t offset[], SimdDouble v0, SimdDouble v1, SimdDouble v2)
{
    __m512d                                  t[4], t5, t6, t7, t8;
    alignas(GMX_SIMD_ALIGNMENT) std::int64_t o[8];
    // TODO: should use fastMultiply
    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(
                                  _mm256_load_si256(reinterpret_cast<const __m256i*>(offset)),
                                  _mm256_set1_epi32(align))));
    t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
    t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
    t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
    t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
    t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
    t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
    t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
    t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
    if (align < 4)
    {
        for (int i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7),
                                  _mm512_castpd256_pd512(_mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]),
                                                                       _mm512_castpd512_pd256(t[i]))));
            _mm512_mask_storeu_pd(
                    base + o[4 + i], avx512Int2Mask(7),
                    _mm512_castpd256_pd512(_mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]),
                                                         _mm512_extractf64x4_pd(t[i], 1))));
        }
    }
    else
    {
        if (align % 4 == 0)
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_store_pd(base + o[0 + i], _mm256_sub_pd(_mm256_load_pd(base + o[0 + i]),
                                                               _mm512_castpd512_pd256(t[i])));
                _mm256_store_pd(base + o[4 + i], _mm256_sub_pd(_mm256_load_pd(base + o[4 + i]),
                                                               _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
        else
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_storeu_pd(base + o[0 + i], _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]),
                                                                _mm512_castpd512_pd256(t[i])));
                _mm256_storeu_pd(base + o[4 + i], _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]),
                                                                _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
    }
}
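
// transposeScatterDecrU is identical to transposeScatterIncrU above, except
// that the transposed triplets are subtracted from memory instead of added.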

static inline void gmx_simdcall expandScalarsToTriplets(SimdDouble  scalar,
                                                        SimdDouble* triplets0,
                                                        SimdDouble* triplets1,
                                                        SimdDouble* triplets2)
{
    triplets0->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
            _mm512_set_epi32(5, 4, 5, 4, 3, 2, 3, 2, 3, 2, 1, 0, 1, 0, 1, 0),
            _mm512_castpd_si512(scalar.simdInternal_)));
    triplets1->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
            _mm512_set_epi32(11, 10, 9, 8, 9, 8, 9, 8, 7, 6, 7, 6, 7, 6, 5, 4),
            _mm512_castpd_si512(scalar.simdInternal_)));
    triplets2->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(
            _mm512_set_epi32(15, 14, 15, 14, 15, 14, 13, 12, 13, 12, 13, 12, 11, 10, 11, 10),
            _mm512_castpd_si512(scalar.simdInternal_)));
}
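
// For reference, with scalar = { s0, s1, ..., s7 } the outputs are
//   triplets0 = { s0, s0, s0, s1, s1, s1, s2, s2 }
//   triplets1 = { s2, s3, s3, s3, s4, s4, s4, s5 }
//   triplets2 = { s5, s5, s6, s6, s6, s7, s7, s7 }
// i.e. each scalar is replicated three times across the three output registers.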


static inline double gmx_simdcall
                     reduceIncr4ReturnSum(double* m, SimdDouble v0, SimdDouble v1, SimdDouble v2, SimdDouble v3)
{
    __m512d t0, t2;
    __m256d t3, t4;

    assert(std::size_t(m) % 32 == 0);

    t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permute_pd(v0.simdInternal_, 0x55));
    t2 = _mm512_add_pd(v2.simdInternal_, _mm512_permute_pd(v2.simdInternal_, 0x55));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xAA), v1.simdInternal_,
                            _mm512_permute_pd(v1.simdInternal_, 0x55));
    t2 = _mm512_mask_add_pd(t2, avx512Int2Mask(0xAA), v3.simdInternal_,
                            _mm512_permute_pd(v3.simdInternal_, 0x55));
    t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0x4E));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xF0), t2, _mm512_shuffle_f64x2(t2, t2, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0xB1));
    t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0x0C), t0, t0, 0xEE);

    t3 = _mm512_castpd512_pd256(t0);
    t4 = _mm256_load_pd(m);
    t4 = _mm256_add_pd(t4, t3);
    _mm256_store_pd(m, t4);

    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));

    return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
}
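
// For reference, reduceIncr4ReturnSum computes s_k as the sum of the 8 elements
// of v_k (k = 0..3), performs m[k] += s_k, and returns s0 + s1 + s2 + s3, using
// in-register shuffles and a single 256-bit update of m.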

static inline SimdDouble gmx_simdcall loadDualHsimd(const double* m0, const double* m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return { _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_load_pd(m0)), _mm256_load_pd(m1), 1) };
}

static inline SimdDouble gmx_simdcall loadDuplicateHsimd(const double* m)
{
    assert(std::size_t(m) % 32 == 0);

    return { _mm512_broadcast_f64x4(_mm256_load_pd(m)) };
}

static inline SimdDouble gmx_simdcall loadU1DualHsimd(const double* m)
{
    return { _mm512_insertf64x4(_mm512_broadcastsd_pd(_mm_load_sd(m)),
                                _mm256_broadcastsd_pd(_mm_load_sd(m + 1)), 1) };
}
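
// For reference, the three half-SIMD loads above return
//   loadDualHsimd(m0, m1)  -> { m0[0..3], m1[0..3] }
//   loadDuplicateHsimd(m)  -> { m[0..3],  m[0..3]  }
//   loadU1DualHsimd(m)     -> { m[0] x4,  m[1] x4  }
// with the first half in the low 256 bits and the second half in the high 256 bits.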


static inline void gmx_simdcall storeDualHsimd(double* m0, double* m1, SimdDouble a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_pd(m0, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m1, _mm512_extractf64x4_pd(a.simdInternal_, 1));
}

static inline void gmx_simdcall incrDualHsimd(double* m0, double* m1, SimdDouble a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256d x;

    // Lower half
    x = _mm256_load_pd(m0);
    x = _mm256_add_pd(x, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m0, x);

    // Upper half
    x = _mm256_load_pd(m1);
    x = _mm256_add_pd(x, _mm512_extractf64x4_pd(a.simdInternal_, 1));
    _mm256_store_pd(m1, x);
}

static inline void gmx_simdcall decr3Hsimd(double* m, SimdDouble a0, SimdDouble a1, SimdDouble a2)
{
    decrHsimd(m, a0);
    decrHsimd(m + GMX_SIMD_DOUBLE_WIDTH / 2, a1);
    decrHsimd(m + GMX_SIMD_DOUBLE_WIDTH, a2);
}
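
// For reference, decr3Hsimd subtracts the sum of the two 4-double halves of
// each argument from consecutive memory quadruplets:
//   m[0..3]  -= a0[0..3] + a0[4..7]
//   m[4..7]  -= a1[0..3] + a1[4..7]
//   m[8..11] -= a2[0..3] + a2[4..7]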

template<int align>
static inline void gmx_simdcall gatherLoadTransposeHsimd(const double*      base0,
                                                         const double*      base1,
                                                         const std::int32_t offset[],
                                                         SimdDouble*        v0,
                                                         SimdDouble*        v1)
{
    __m128i idx0, idx1;
    __m256i idx;
    __m512d tmp1, tmp2;

    assert(std::size_t(offset) % 16 == 0);
    assert(std::size_t(base0) % 16 == 0);
    assert(std::size_t(base1) % 16 == 0);

    idx0 = _mm_load_si128(reinterpret_cast<const __m128i*>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    idx0 = _mm_slli_epi32(idx0, align == 2 ? 1 : 2);

    idx1 = _mm_add_epi32(idx0, _mm_set1_epi32(1));

    idx = _mm256_inserti128_si256(_mm256_castsi128_si256(idx0), idx1, 1);

    constexpr size_t scale = sizeof(double);
    tmp1 = _mm512_i32gather_pd(idx, base0, scale); // TODO: Might be faster to use individual loads
    tmp2 = _mm512_i32gather_pd(idx, base1, scale);

    v0->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0x44);
    v1->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0xEE);
}
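
// For reference, gatherLoadTransposeHsimd returns
//   v0 = { base0[align*offset[i]],     i = 0..3,  base1[align*offset[i]],     i = 0..3 }
//   v1 = { base0[align*offset[i] + 1], i = 0..3,  base1[align*offset[i] + 1], i = 0..3 }
// i.e. pairs gathered from two separate arrays, transposed into two registers.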

static inline double gmx_simdcall reduceIncr4ReturnSumHsimd(double* m, SimdDouble v0, SimdDouble v1)
{
    __m512d t0;
    __m256d t2, t3;

    assert(std::size_t(m) % 32 == 0);

    t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permutex_pd(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xCC), v1.simdInternal_,
                            _mm512_permutex_pd(v1.simdInternal_, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
    t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0xAA), t0, t0, 0xEE);

    t2 = _mm512_castpd512_pd256(t0);
    t3 = _mm256_load_pd(m);
    t3 = _mm256_add_pd(t3, t2);
    _mm256_store_pd(m, t3);

    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));

    return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
}
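
// For reference, reduceIncr4ReturnSumHsimd treats v0 and v1 as two pairs of
// 4-double half-registers, sums each half, performs
//   m[0] += sum(v0 low),  m[1] += sum(v0 high),
//   m[2] += sum(v1 low),  m[3] += sum(v1 high),
// and returns the total of the four sums.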

static inline SimdDouble gmx_simdcall loadU4NOffset(const double* m, int offset)
{
    return { _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_loadu_pd(m)),
                                _mm256_loadu_pd(m + offset), 1) };
}
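
// For reference, loadU4NOffset returns { m[0..3], m[offset .. offset+3] },
// loaded without any alignment requirement.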

} // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H