3f537e10f9b787d3cf435d201480ac1517b06915
[alexxy/gromacs.git] / src / gromacs / simd / impl_x86_avx_256 / impl_x86_avx_256_simd_float.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 #ifndef GMX_SIMD_IMPL_X86_AVX_256_SIMD_FLOAT_H
37 #define GMX_SIMD_IMPL_X86_AVX_256_SIMD_FLOAT_H
38
39 #include "config.h"
40
41 #include <cassert>
42 #include <cstddef>
43 #include <cstdint>
44
45 #include <immintrin.h>
46
47 #include "gromacs/math/utilities.h"
48
49 namespace gmx
50 {
51
52 class SimdFloat
53 {
54     public:
55         SimdFloat() {}
56
57         SimdFloat(float f) : simdInternal_(_mm256_set1_ps(f)) {}
58
59         // Internal utility constructor to simplify return statements
60         SimdFloat(__m256 simd) : simdInternal_(simd) {}
61
62         __m256  simdInternal_;
63 };
64
65 class SimdFInt32
66 {
67     public:
68         SimdFInt32() {}
69
70         SimdFInt32(std::int32_t i) : simdInternal_(_mm256_set1_epi32(i)) {}
71
72         // Internal utility constructor to simplify return statements
73         SimdFInt32(__m256i simd) : simdInternal_(simd) {}
74
75         __m256i  simdInternal_;
76 };
77
78 class SimdFBool
79 {
80     public:
81         SimdFBool() {}
82
83         SimdFBool(bool b) : simdInternal_(_mm256_castsi256_ps(_mm256_set1_epi32( b ? 0xFFFFFFFF : 0))) {}
84
85         // Internal utility constructor to simplify return statements
86         SimdFBool(__m256 simd) : simdInternal_(simd) {}
87
88         __m256  simdInternal_;
89 };
90
91 static inline SimdFloat gmx_simdcall
92 simdLoad(const float *m, SimdFloatTag = {})
93 {
94     assert(std::size_t(m) % 32 == 0);
95     return {
96                _mm256_load_ps(m)
97     };
98 }
99
100 static inline void gmx_simdcall
101 store(float *m, SimdFloat a)
102 {
103     assert(std::size_t(m) % 32 == 0);
104     _mm256_store_ps(m, a.simdInternal_);
105 }
106
107 static inline SimdFloat gmx_simdcall
108 simdLoadU(const float *m, SimdFloatTag = {})
109 {
110     return {
111                _mm256_loadu_ps(m)
112     };
113 }
114
115 static inline void gmx_simdcall
116 storeU(float *m, SimdFloat a)
117 {
118     _mm256_storeu_ps(m, a.simdInternal_);
119 }
120
121 static inline SimdFloat gmx_simdcall
122 setZeroF()
123 {
124     return {
125                _mm256_setzero_ps()
126     };
127 }
128
129 static inline SimdFInt32 gmx_simdcall
130 simdLoad(const std::int32_t * m, SimdFInt32Tag)
131 {
132     assert(std::size_t(m) % 32 == 0);
133     return {
134                _mm256_load_si256(reinterpret_cast<const __m256i *>(m))
135     };
136 }
137
138 static inline void gmx_simdcall
139 store(std::int32_t * m, SimdFInt32 a)
140 {
141     assert(std::size_t(m) % 32 == 0);
142     _mm256_store_si256(reinterpret_cast<__m256i *>(m), a.simdInternal_);
143 }
144
145 static inline SimdFInt32 gmx_simdcall
146 simdLoadU(const std::int32_t *m, SimdFInt32Tag)
147 {
148     return {
149                _mm256_loadu_si256(reinterpret_cast<const __m256i *>(m))
150     };
151 }
152
153 static inline void gmx_simdcall
154 storeU(std::int32_t * m, SimdFInt32 a)
155 {
156     _mm256_storeu_si256(reinterpret_cast<__m256i *>(m), a.simdInternal_);
157 }
158
159 static inline SimdFInt32 gmx_simdcall
160 setZeroFI()
161 {
162     return {
163                _mm256_setzero_si256()
164     };
165 }
166
167 template<int index>
168 static inline std::int32_t gmx_simdcall
169 extract(SimdFInt32 a)
170 {
171     return _mm_extract_epi32(_mm256_extractf128_si256(a.simdInternal_, index>>2), index & 0x3);
172 }
173
174 static inline SimdFloat gmx_simdcall
175 operator&(SimdFloat a, SimdFloat b)
176 {
177     return {
178                _mm256_and_ps(a.simdInternal_, b.simdInternal_)
179     };
180 }
181
182 static inline SimdFloat gmx_simdcall
183 andNot(SimdFloat a, SimdFloat b)
184 {
185     return {
186                _mm256_andnot_ps(a.simdInternal_, b.simdInternal_)
187     };
188 }
189
190 static inline SimdFloat gmx_simdcall
191 operator|(SimdFloat a, SimdFloat b)
192 {
193     return {
194                _mm256_or_ps(a.simdInternal_, b.simdInternal_)
195     };
196 }
197
198 static inline SimdFloat gmx_simdcall
199 operator^(SimdFloat a, SimdFloat b)
200 {
201     return {
202                _mm256_xor_ps(a.simdInternal_, b.simdInternal_)
203     };
204 }
205
206 static inline SimdFloat gmx_simdcall
207 operator+(SimdFloat a, SimdFloat b)
208 {
209     return {
210                _mm256_add_ps(a.simdInternal_, b.simdInternal_)
211     };
212 }
213
214 static inline SimdFloat gmx_simdcall
215 operator-(SimdFloat a, SimdFloat b)
216 {
217     return {
218                _mm256_sub_ps(a.simdInternal_, b.simdInternal_)
219     };
220 }
221
222 static inline SimdFloat gmx_simdcall
223 operator-(SimdFloat x)
224 {
225     return {
226                _mm256_xor_ps(x.simdInternal_, _mm256_set1_ps(GMX_FLOAT_NEGZERO))
227     };
228 }
229
230 static inline SimdFloat gmx_simdcall
231 operator*(SimdFloat a, SimdFloat b)
232 {
233     return {
234                _mm256_mul_ps(a.simdInternal_, b.simdInternal_)
235     };
236 }
237
238 // Override for AVX2 and higher
239 #if GMX_SIMD_X86_AVX_256
240 static inline SimdFloat gmx_simdcall
241 fma(SimdFloat a, SimdFloat b, SimdFloat c)
242 {
243     return {
244                _mm256_add_ps(_mm256_mul_ps(a.simdInternal_, b.simdInternal_), c.simdInternal_)
245     };
246 }
247
248 static inline SimdFloat gmx_simdcall
249 fms(SimdFloat a, SimdFloat b, SimdFloat c)
250 {
251     return {
252                _mm256_sub_ps(_mm256_mul_ps(a.simdInternal_, b.simdInternal_), c.simdInternal_)
253     };
254 }
255
256 static inline SimdFloat gmx_simdcall
257 fnma(SimdFloat a, SimdFloat b, SimdFloat c)
258 {
259     return {
260                _mm256_sub_ps(c.simdInternal_, _mm256_mul_ps(a.simdInternal_, b.simdInternal_))
261     };
262 }
263
264 static inline SimdFloat gmx_simdcall
265 fnms(SimdFloat a, SimdFloat b, SimdFloat c)
266 {
267     return {
268                _mm256_sub_ps(_mm256_setzero_ps(), _mm256_add_ps(_mm256_mul_ps(a.simdInternal_, b.simdInternal_), c.simdInternal_))
269     };
270 }
271 #endif
272
273 static inline SimdFloat gmx_simdcall
274 rsqrt(SimdFloat x)
275 {
276     return {
277                _mm256_rsqrt_ps(x.simdInternal_)
278     };
279 }
280
281 static inline SimdFloat gmx_simdcall
282 rcp(SimdFloat x)
283 {
284     return {
285                _mm256_rcp_ps(x.simdInternal_)
286     };
287 }
288
289 static inline SimdFloat gmx_simdcall
290 maskAdd(SimdFloat a, SimdFloat b, SimdFBool m)
291 {
292     return {
293                _mm256_add_ps(a.simdInternal_, _mm256_and_ps(b.simdInternal_, m.simdInternal_))
294     };
295 }
296
297 static inline SimdFloat gmx_simdcall
298 maskzMul(SimdFloat a, SimdFloat b, SimdFBool m)
299 {
300     return {
301                _mm256_and_ps(_mm256_mul_ps(a.simdInternal_, b.simdInternal_), m.simdInternal_)
302     };
303 }
304
305 static inline SimdFloat
306 maskzFma(SimdFloat a, SimdFloat b, SimdFloat c, SimdFBool m)
307 {
308     return {
309                _mm256_and_ps(_mm256_add_ps(_mm256_mul_ps(a.simdInternal_, b.simdInternal_), c.simdInternal_), m.simdInternal_)
310     };
311 }
312
313 static inline SimdFloat
314 maskzRsqrt(SimdFloat x, SimdFBool m)
315 {
316 #ifndef NDEBUG
317     x.simdInternal_ = _mm256_blendv_ps(_mm256_set1_ps(1.0f), x.simdInternal_, m.simdInternal_);
318 #endif
319     return {
320                _mm256_and_ps(_mm256_rsqrt_ps(x.simdInternal_), m.simdInternal_)
321     };
322 }
323
324 static inline SimdFloat
325 maskzRcp(SimdFloat x, SimdFBool m)
326 {
327 #ifndef NDEBUG
328     x.simdInternal_ = _mm256_blendv_ps(_mm256_set1_ps(1.0f), x.simdInternal_, m.simdInternal_);
329 #endif
330     return {
331                _mm256_and_ps(_mm256_rcp_ps(x.simdInternal_), m.simdInternal_)
332     };
333 }
334
335 static inline SimdFloat gmx_simdcall
336 abs(SimdFloat x)
337 {
338     return {
339                _mm256_andnot_ps( _mm256_set1_ps(GMX_FLOAT_NEGZERO), x.simdInternal_ )
340     };
341 }
342
343 static inline SimdFloat gmx_simdcall
344 max(SimdFloat a, SimdFloat b)
345 {
346     return {
347                _mm256_max_ps(a.simdInternal_, b.simdInternal_)
348     };
349 }
350
351 static inline SimdFloat gmx_simdcall
352 min(SimdFloat a, SimdFloat b)
353 {
354     return {
355                _mm256_min_ps(a.simdInternal_, b.simdInternal_)
356     };
357 }
358
359 static inline SimdFloat gmx_simdcall
360 round(SimdFloat x)
361 {
362     return {
363                _mm256_round_ps(x.simdInternal_, _MM_FROUND_NINT)
364     };
365 }
366
367 static inline SimdFloat gmx_simdcall
368 trunc(SimdFloat x)
369 {
370     return {
371                _mm256_round_ps(x.simdInternal_, _MM_FROUND_TRUNC)
372     };
373 }
374
375 // Override for AVX2 and higher
376 #if GMX_SIMD_X86_AVX_256
377 static inline SimdFloat gmx_simdcall
378 frexp(SimdFloat value, SimdFInt32 * exponent)
379 {
380     const __m256  exponentMask      = _mm256_castsi256_ps(_mm256_set1_epi32(0x7F800000));
381     const __m256  mantissaMask      = _mm256_castsi256_ps(_mm256_set1_epi32(0x807FFFFF));
382     const __m256  half              = _mm256_set1_ps(0.5);
383     const __m128i exponentBias      = _mm_set1_epi32(126);  // add 1 to make our definition identical to frexp()
384     __m256i       iExponent;
385     __m128i       iExponentLow, iExponentHigh;
386
387     iExponent               = _mm256_castps_si256(_mm256_and_ps(value.simdInternal_, exponentMask));
388     iExponentHigh           = _mm256_extractf128_si256(iExponent, 0x1);
389     iExponentLow            = _mm256_castsi256_si128(iExponent);
390     iExponentLow            = _mm_srli_epi32(iExponentLow, 23);
391     iExponentHigh           = _mm_srli_epi32(iExponentHigh, 23);
392     iExponentLow            = _mm_sub_epi32(iExponentLow, exponentBias);
393     iExponentHigh           = _mm_sub_epi32(iExponentHigh, exponentBias);
394     iExponent               = _mm256_castsi128_si256(iExponentLow);
395     exponent->simdInternal_ = _mm256_insertf128_si256(iExponent, iExponentHigh, 0x1);
396
397     return {
398                _mm256_or_ps(_mm256_and_ps(value.simdInternal_, mantissaMask), half)
399     };
400
401 }
402
403 template <MathOptimization opt = MathOptimization::Safe>
404 static inline SimdFloat gmx_simdcall
405 ldexp(SimdFloat value, SimdFInt32 exponent)
406 {
407     const __m128i exponentBias      = _mm_set1_epi32(127);
408     __m256i       iExponent;
409     __m128i       iExponentLow, iExponentHigh;
410
411     iExponentHigh = _mm256_extractf128_si256(exponent.simdInternal_, 0x1);
412     iExponentLow  = _mm256_castsi256_si128(exponent.simdInternal_);
413
414     iExponentLow  = _mm_add_epi32(iExponentLow, exponentBias);
415     iExponentHigh = _mm_add_epi32(iExponentHigh, exponentBias);
416
417     if (opt == MathOptimization::Safe)
418     {
419         // Make sure biased argument is not negative
420         iExponentLow  = _mm_max_epi32(iExponentLow, _mm_setzero_si128());
421         iExponentHigh = _mm_max_epi32(iExponentHigh, _mm_setzero_si128());
422     }
423
424     iExponentLow  = _mm_slli_epi32(iExponentLow, 23);
425     iExponentHigh = _mm_slli_epi32(iExponentHigh, 23);
426     iExponent     = _mm256_castsi128_si256(iExponentLow);
427     iExponent     = _mm256_insertf128_si256(iExponent, iExponentHigh, 0x1);
428     return {
429                _mm256_mul_ps(value.simdInternal_, _mm256_castsi256_ps(iExponent))
430     };
431 }
432 #endif
433
434 static inline float gmx_simdcall
435 reduce(SimdFloat a)
436 {
437     __m128 t0;
438     t0 = _mm_add_ps(_mm256_castps256_ps128(a.simdInternal_), _mm256_extractf128_ps(a.simdInternal_, 0x1));
439     t0 = _mm_add_ps(t0, _mm_permute_ps(t0, _MM_SHUFFLE(1, 0, 3, 2)));
440     t0 = _mm_add_ss(t0, _mm_permute_ps(t0, _MM_SHUFFLE(0, 3, 2, 1)));
441     return *reinterpret_cast<float *>(&t0);
442 }
443
444 static inline SimdFBool gmx_simdcall
445 operator==(SimdFloat a, SimdFloat b)
446 {
447     return {
448                _mm256_cmp_ps(a.simdInternal_, b.simdInternal_, _CMP_EQ_OQ)
449     };
450 }
451
452 static inline SimdFBool gmx_simdcall
453 operator!=(SimdFloat a, SimdFloat b)
454 {
455     return {
456                _mm256_cmp_ps(a.simdInternal_, b.simdInternal_, _CMP_NEQ_OQ)
457     };
458 }
459
460 static inline SimdFBool gmx_simdcall
461 operator<(SimdFloat a, SimdFloat b)
462 {
463     return {
464                _mm256_cmp_ps(a.simdInternal_, b.simdInternal_, _CMP_LT_OQ)
465     };
466 }
467
468 static inline SimdFBool gmx_simdcall
469 operator<=(SimdFloat a, SimdFloat b)
470 {
471     return {
472                _mm256_cmp_ps(a.simdInternal_, b.simdInternal_, _CMP_LE_OQ)
473     };
474 }
475
476 // Override for AVX2 and higher
477 #if GMX_SIMD_X86_AVX_256
478 static inline SimdFBool gmx_simdcall
479 testBits(SimdFloat a)
480 {
481     __m256 tst = _mm256_cvtepi32_ps(_mm256_castps_si256(a.simdInternal_));
482
483     return {
484                _mm256_cmp_ps(tst, _mm256_setzero_ps(), _CMP_NEQ_OQ)
485     };
486 }
487 #endif
488
489 static inline SimdFBool gmx_simdcall
490 operator&&(SimdFBool a, SimdFBool b)
491 {
492     return {
493                _mm256_and_ps(a.simdInternal_, b.simdInternal_)
494     };
495 }
496
497 static inline SimdFBool gmx_simdcall
498 operator||(SimdFBool a, SimdFBool b)
499 {
500     return {
501                _mm256_or_ps(a.simdInternal_, b.simdInternal_)
502     };
503 }
504
505 static inline bool gmx_simdcall
506 anyTrue(SimdFBool a) { return _mm256_movemask_ps(a.simdInternal_) != 0; }
507
508 static inline SimdFloat gmx_simdcall
509 selectByMask(SimdFloat a, SimdFBool mask)
510 {
511     return {
512                _mm256_and_ps(a.simdInternal_, mask.simdInternal_)
513     };
514 }
515
516 static inline SimdFloat gmx_simdcall
517 selectByNotMask(SimdFloat a, SimdFBool mask)
518 {
519     return {
520                _mm256_andnot_ps(mask.simdInternal_, a.simdInternal_)
521     };
522 }
523
524 static inline SimdFloat gmx_simdcall
525 blend(SimdFloat a, SimdFloat b, SimdFBool sel)
526 {
527     return {
528                _mm256_blendv_ps(a.simdInternal_, b.simdInternal_, sel.simdInternal_)
529     };
530 }
531
532 static inline SimdFInt32 gmx_simdcall
533 cvtR2I(SimdFloat a)
534 {
535     return {
536                _mm256_cvtps_epi32(a.simdInternal_)
537     };
538 }
539
540 static inline SimdFInt32 gmx_simdcall
541 cvttR2I(SimdFloat a)
542 {
543     return {
544                _mm256_cvttps_epi32(a.simdInternal_)
545     };
546 }
547
548 static inline SimdFloat gmx_simdcall
549 cvtI2R(SimdFInt32 a)
550 {
551     return {
552                _mm256_cvtepi32_ps(a.simdInternal_)
553     };
554 }
555
556 }      // namespace gmx
557
558 #endif // GMX_SIMD_IMPL_X86_AVX_256_SIMD_FLOAT_H