Implement changes for CMake policy 0068
[alexxy/gromacs.git] / src / gromacs / simd / impl_x86_avx_512 / impl_x86_avx_512_util_double.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 #ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H
37 #define GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H
38
39 #include "config.h"
40
41 #include <cassert>
42 #include <cstdint>
43
44 #include <immintrin.h>
45
46 #include "gromacs/utility/basedefinitions.h"
47
48 #include "impl_x86_avx_512_general.h"
49 #include "impl_x86_avx_512_simd_double.h"
50
51 namespace gmx
52 {
53
54 static const int c_simdBestPairAlignmentDouble = 2;
55
56 namespace
57 {
58 // Multiply function optimized for powers of 2, for which it is done by
59 // shifting. Currently up to 8 is accelerated. Could be accelerated for any
60 // number with a constexpr log2 function.
61 template<int n>
62 SimdDInt32 fastMultiply(SimdDInt32 x)
63 {
64     if (n == 2)
65     {
66         return _mm256_slli_epi32(x.simdInternal_, 1);
67     }
68     else if (n == 4)
69     {
70         return _mm256_slli_epi32(x.simdInternal_, 2);
71     }
72     else if (n == 8)
73     {
74         return _mm256_slli_epi32(x.simdInternal_, 3);
75     }
76     else
77     {
78         return x * n;
79     }
80 }
81
82 template<int align>
83 static inline void gmx_simdcall
84 gatherLoadBySimdIntTranspose(const double *, SimdDInt32)
85 {
86     //Nothing to do. Termination of recursion.
87 }
88 }
89
90
91 template <int align, typename ... Targs>
92 static inline void gmx_simdcall
93 gatherLoadBySimdIntTranspose(const double * base, SimdDInt32 offset, SimdDouble *v, Targs... Fargs)
94 {
95     if (align > 1)
96     {
97         offset = fastMultiply<align>(offset);
98     }
99     constexpr size_t scale = sizeof(double);
100     v->simdInternal_ = _mm512_i32gather_pd(offset.simdInternal_, base, scale);
101     gatherLoadBySimdIntTranspose<1>(base+1, offset, Fargs ...);
102 }
103
104 template <int align, typename ... Targs>
105 static inline void gmx_simdcall
106 gatherLoadUBySimdIntTranspose(const double *base, SimdDInt32 offset, Targs... Fargs)
107 {
108     gatherLoadBySimdIntTranspose<align>(base, offset, Fargs ...);
109 }
110
111 template <int align, typename ... Targs>
112 static inline void gmx_simdcall
113 gatherLoadTranspose(const double *base, const std::int32_t offset[], Targs... Fargs)
114 {
115     gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), Fargs ...);
116 }
117
118 template <int align, typename ... Targs>
119 static inline void gmx_simdcall
120 gatherLoadUTranspose(const double *base, const std::int32_t offset[], Targs... Fargs)
121 {
122     gatherLoadTranspose<align>(base, offset, Fargs ...);
123 }
124
125 template <int align>
126 static inline void gmx_simdcall
127 transposeScatterStoreU(double *             base,
128                        const std::int32_t   offset[],
129                        SimdDouble           v0,
130                        SimdDouble           v1,
131                        SimdDouble           v2)
132 {
133     SimdDInt32 simdoffset = simdLoad(offset, SimdDInt32Tag());
134
135     if (align > 1)
136     {
137         simdoffset = fastMultiply<align>(simdoffset);;
138     }
139     constexpr size_t scale = sizeof(double);
140     _mm512_i32scatter_pd(base,   simdoffset.simdInternal_, v0.simdInternal_, scale);
141     _mm512_i32scatter_pd(&(base[1]), simdoffset.simdInternal_, v1.simdInternal_, scale);
142     _mm512_i32scatter_pd(&(base[2]), simdoffset.simdInternal_, v2.simdInternal_, scale);
143 }
144
145 template <int align>
146 static inline void gmx_simdcall
147 transposeScatterIncrU(double *            base,
148                       const std::int32_t  offset[],
149                       SimdDouble          v0,
150                       SimdDouble          v1,
151                       SimdDouble          v2)
152 {
153     __m512d t[4], t5, t6, t7, t8;
154     alignas(GMX_SIMD_ALIGNMENT) std::int64_t    o[8];
155     //TODO: should use fastMultiply
156     _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(_mm256_load_si256((const __m256i*)(offset  )), _mm256_set1_epi32(align))));
157     t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
158     t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
159     t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
160     t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
161     t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
162     t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
163     t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
164     t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
165     if (align < 4)
166     {
167         for (int i = 0; i < 4; i++)
168         {
169             _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
170                                           _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i]))));
171             _mm512_mask_storeu_pd(base + o[4 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
172                                           _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1))));
173         }
174     }
175     else
176     {
177         if (align % 4 == 0)
178         {
179             for (int i = 0; i < 4; i++)
180             {
181                 _mm256_store_pd(base + o[0 + i],
182                                 _mm256_add_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
183                 _mm256_store_pd(base + o[4 + i],
184                                 _mm256_add_pd(_mm256_load_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
185             }
186         }
187         else
188         {
189             for (int i = 0; i < 4; i++)
190             {
191                 _mm256_storeu_pd(base + o[0 + i],
192                                  _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
193                 _mm256_storeu_pd(base + o[4 + i],
194                                  _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
195             }
196         }
197     }
198 }
199
200 template <int align>
201 static inline void gmx_simdcall
202 transposeScatterDecrU(double *            base,
203                       const std::int32_t  offset[],
204                       SimdDouble          v0,
205                       SimdDouble          v1,
206                       SimdDouble          v2)
207 {
208     __m512d t[4], t5, t6, t7, t8;
209     alignas(GMX_SIMD_ALIGNMENT) std::int64_t    o[8];
210     //TODO: should use fastMultiply
211     _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(_mm256_load_si256((const __m256i*)(offset  )), _mm256_set1_epi32(align))));
212     t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
213     t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
214     t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
215     t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
216     t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
217     t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
218     t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
219     t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
220     if (align < 4)
221     {
222         for (int i = 0; i < 4; i++)
223         {
224             _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
225                                           _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i]))));
226             _mm512_mask_storeu_pd(base + o[4 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
227                                           _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1))));
228         }
229     }
230     else
231     {
232         if (align % 4 == 0)
233         {
234             for (int i = 0; i < 4; i++)
235             {
236                 _mm256_store_pd(base + o[0 + i],
237                                 _mm256_sub_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
238                 _mm256_store_pd(base + o[4 + i],
239                                 _mm256_sub_pd(_mm256_load_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
240             }
241         }
242         else
243         {
244             for (int i = 0; i < 4; i++)
245             {
246                 _mm256_storeu_pd(base + o[0 + i],
247                                  _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
248                 _mm256_storeu_pd(base + o[4 + i],
249                                  _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
250             }
251         }
252     }
253 }
254
255 static inline void gmx_simdcall
256 expandScalarsToTriplets(SimdDouble    scalar,
257                         SimdDouble *  triplets0,
258                         SimdDouble *  triplets1,
259                         SimdDouble *  triplets2)
260 {
261     triplets0->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(5, 4, 5, 4, 3, 2, 3, 2, 3, 2, 1, 0, 1, 0, 1, 0),
262                                                                             _mm512_castpd_si512(scalar.simdInternal_)));
263     triplets1->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(11, 10, 9, 8, 9, 8, 9, 8, 7, 6, 7, 6, 7, 6, 5, 4),
264                                                                             _mm512_castpd_si512(scalar.simdInternal_)));
265     triplets2->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(15, 14, 15, 14, 15, 14, 13, 12, 13, 12, 13, 12, 11, 10, 11, 10),
266                                                                             _mm512_castpd_si512(scalar.simdInternal_)));
267 }
268
269
270 static inline double gmx_simdcall
271 reduceIncr4ReturnSum(double *    m,
272                      SimdDouble  v0,
273                      SimdDouble  v1,
274                      SimdDouble  v2,
275                      SimdDouble  v3)
276 {
277     __m512d t0, t2;
278     __m256d t3, t4;
279
280     assert(std::size_t(m) % 32 == 0);
281
282     t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permute_pd(v0.simdInternal_, 0x55));
283     t2 = _mm512_add_pd(v2.simdInternal_, _mm512_permute_pd(v2.simdInternal_, 0x55));
284     t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xAA), v1.simdInternal_, _mm512_permute_pd(v1.simdInternal_, 0x55));
285     t2 = _mm512_mask_add_pd(t2, avx512Int2Mask(0xAA), v3.simdInternal_, _mm512_permute_pd(v3.simdInternal_, 0x55));
286     t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0x4E));
287     t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xF0), t2, _mm512_shuffle_f64x2(t2, t2, 0x4E));
288     t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0xB1));
289     t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0x0C), t0, t0, 0xEE);
290
291     t3 = _mm512_castpd512_pd256(t0);
292     t4 = _mm256_load_pd(m);
293     t4 = _mm256_add_pd(t4, t3);
294     _mm256_store_pd(m, t4);
295
296     t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
297     t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
298
299     return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
300 }
301
302 static inline SimdDouble gmx_simdcall
303 loadDualHsimd(const double * m0,
304               const double * m1)
305 {
306     assert(std::size_t(m0) % 32 == 0);
307     assert(std::size_t(m1) % 32 == 0);
308
309     return {
310                _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_load_pd(m0)),
311                                   _mm256_load_pd(m1), 1)
312     };
313 }
314
315 static inline SimdDouble gmx_simdcall
316 loadDuplicateHsimd(const double * m)
317 {
318     assert(std::size_t(m) % 32 == 0);
319
320     return {
321                _mm512_broadcast_f64x4(_mm256_load_pd(m))
322     };
323 }
324
325 static inline SimdDouble gmx_simdcall
326 loadU1DualHsimd(const double * m)
327 {
328     return {
329                _mm512_insertf64x4(_mm512_broadcastsd_pd(_mm_load_sd(m)),
330                                   _mm256_broadcastsd_pd(_mm_load_sd(m+1)), 1)
331     };
332 }
333
334
335 static inline void gmx_simdcall
336 storeDualHsimd(double *     m0,
337                double *     m1,
338                SimdDouble   a)
339 {
340     assert(std::size_t(m0) % 32 == 0);
341     assert(std::size_t(m1) % 32 == 0);
342
343     _mm256_store_pd(m0, _mm512_castpd512_pd256(a.simdInternal_));
344     _mm256_store_pd(m1, _mm512_extractf64x4_pd(a.simdInternal_, 1));
345 }
346
347 static inline void gmx_simdcall
348 incrDualHsimd(double *     m0,
349               double *     m1,
350               SimdDouble   a)
351 {
352     assert(std::size_t(m0) % 32 == 0);
353     assert(std::size_t(m1) % 32 == 0);
354
355     __m256d x;
356
357     // Lower half
358     x = _mm256_load_pd(m0);
359     x = _mm256_add_pd(x, _mm512_castpd512_pd256(a.simdInternal_));
360     _mm256_store_pd(m0, x);
361
362     // Upper half
363     x = _mm256_load_pd(m1);
364     x = _mm256_add_pd(x, _mm512_extractf64x4_pd(a.simdInternal_, 1));
365     _mm256_store_pd(m1, x);
366 }
367
368 static inline void gmx_simdcall
369 decrHsimd(double *    m,
370           SimdDouble  a)
371 {
372     __m256d t;
373
374     assert(std::size_t(m) % 32 == 0);
375
376     a.simdInternal_ = _mm512_add_pd(a.simdInternal_, _mm512_shuffle_f64x2(a.simdInternal_, a.simdInternal_, 0xEE));
377     t               = _mm256_load_pd(m);
378     t               = _mm256_sub_pd(t, _mm512_castpd512_pd256(a.simdInternal_));
379     _mm256_store_pd(m, t);
380 }
381
382
383 template <int align>
384 static inline void gmx_simdcall
385 gatherLoadTransposeHsimd(const double *       base0,
386                          const double *       base1,
387                          const std::int32_t   offset[],
388                          SimdDouble *         v0,
389                          SimdDouble *         v1)
390 {
391     __m128i  idx0, idx1;
392     __m256i  idx;
393     __m512d  tmp1, tmp2;
394
395     assert(std::size_t(offset) % 16 == 0);
396     assert(std::size_t(base0) % 16 == 0);
397     assert(std::size_t(base1) % 16 == 0);
398
399     idx0 = _mm_load_si128(reinterpret_cast<const __m128i*>(offset));
400
401     static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
402     idx0 = _mm_slli_epi32(idx0, align == 2 ? 1 : 2);
403
404     idx1 = _mm_add_epi32(idx0, _mm_set1_epi32(1));
405
406     idx = _mm256_inserti128_si256(_mm256_castsi128_si256(idx0), idx1, 1);
407
408     constexpr size_t scale = sizeof(double);
409     tmp1 = _mm512_i32gather_pd(idx, base0, scale); //TODO: Might be faster to use invidual loads
410     tmp2 = _mm512_i32gather_pd(idx, base1, scale);
411
412     v0->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0x44 );
413     v1->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0xEE );
414 }
415
416 static inline double gmx_simdcall
417 reduceIncr4ReturnSumHsimd(double *     m,
418                           SimdDouble   v0,
419                           SimdDouble   v1)
420 {
421     __m512d  t0;
422     __m256d  t2, t3;
423
424     assert(std::size_t(m) % 32 == 0);
425
426     t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permutex_pd(v0.simdInternal_, 0x4E));
427     t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xCC), v1.simdInternal_, _mm512_permutex_pd(v1.simdInternal_, 0x4E));
428     t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
429     t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0xAA), t0, t0, 0xEE);
430
431     t2 = _mm512_castpd512_pd256(t0);
432     t3 = _mm256_load_pd(m);
433     t3 = _mm256_add_pd(t3, t2);
434     _mm256_store_pd(m, t3);
435
436     t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0x4E));
437     t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
438
439     return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0));
440 }
441
442 static inline SimdDouble gmx_simdcall
443 loadU4NOffset(const double *m, int offset)
444 {
445     return {
446                _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_loadu_pd(m)),
447                                   _mm256_loadu_pd(m+offset), 1)
448     };
449 }
450
451 }      // namespace gmx
452
453 #endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H