90ab9ef5a45810027dd27a7f45f1f62fa8b046d6
[alexxy/gromacs.git] / src / gromacs / simd / impl_arm_neon_asimd / impl_arm_neon_asimd_simd_float.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2014,2015,2016,2017,2019,2020, by the GROMACS development team.
5  * Copyright (c) 2021, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36 #ifndef GMX_SIMD_IMPL_ARM_NEON_ASIMD_SIMD_FLOAT_H
37 #define GMX_SIMD_IMPL_ARM_NEON_ASIMD_SIMD_FLOAT_H
38
39 #include "config.h"
40
41 #include <cassert>
42 #include <cstddef>
43 #include <cstdint>
44
45 #include <arm_neon.h>
46
47 #include "gromacs/math/utilities.h"
48
49 namespace gmx
50 {
51
52 class SimdFloat
53 {
54 public:
55     SimdFloat() {}
56
57     SimdFloat(float f) : simdInternal_(vdupq_n_f32(f)) {}
58
59     // Internal utility constructor to simplify return statements
60     SimdFloat(float32x4_t simd) : simdInternal_(simd) {}
61
62     float32x4_t simdInternal_;
63 };
64
65 class SimdFInt32
66 {
67 public:
68     SimdFInt32() {}
69
70     SimdFInt32(std::int32_t i) : simdInternal_(vdupq_n_s32(i)) {}
71
72     // Internal utility constructor to simplify return statements
73     SimdFInt32(int32x4_t simd) : simdInternal_(simd) {}
74
75     int32x4_t simdInternal_;
76 };
77
78 class SimdFBool
79 {
80 public:
81     SimdFBool() {}
82
83     SimdFBool(bool b) : simdInternal_(vdupq_n_u32(b ? 0xFFFFFFFF : 0)) {}
84
85     // Internal utility constructor to simplify return statements
86     SimdFBool(uint32x4_t simd) : simdInternal_(simd) {}
87
88     uint32x4_t simdInternal_;
89 };
90
91 class SimdFIBool
92 {
93 public:
94     SimdFIBool() {}
95
96     SimdFIBool(bool b) : simdInternal_(vdupq_n_u32(b ? 0xFFFFFFFF : 0)) {}
97
98     // Internal utility constructor to simplify return statements
99     SimdFIBool(uint32x4_t simd) : simdInternal_(simd) {}
100
101     uint32x4_t simdInternal_;
102 };
103
104 static inline SimdFloat gmx_simdcall simdLoad(const float* m, SimdFloatTag = {})
105 {
106     assert(std::size_t(m) % 16 == 0);
107     return { vld1q_f32(m) };
108 }
109
110 static inline void gmx_simdcall store(float* m, SimdFloat a)
111 {
112     assert(std::size_t(m) % 16 == 0);
113     vst1q_f32(m, a.simdInternal_);
114 }
115
116 static inline SimdFloat gmx_simdcall simdLoadU(const float* m, SimdFloatTag = {})
117 {
118     return { vld1q_f32(m) };
119 }
120
121 static inline void gmx_simdcall storeU(float* m, SimdFloat a)
122 {
123     vst1q_f32(m, a.simdInternal_);
124 }
125
126 static inline SimdFloat gmx_simdcall setZeroF()
127 {
128     return { vdupq_n_f32(0.0F) };
129 }
130
131 static inline SimdFInt32 gmx_simdcall simdLoad(const std::int32_t* m, SimdFInt32Tag)
132 {
133     assert(std::size_t(m) % 16 == 0);
134     return { vld1q_s32(m) };
135 }
136
137 static inline void gmx_simdcall store(std::int32_t* m, SimdFInt32 a)
138 {
139     assert(std::size_t(m) % 16 == 0);
140     vst1q_s32(m, a.simdInternal_);
141 }
142
143 static inline SimdFInt32 gmx_simdcall simdLoadU(const std::int32_t* m, SimdFInt32Tag)
144 {
145     return { vld1q_s32(m) };
146 }
147
148 static inline void gmx_simdcall storeU(std::int32_t* m, SimdFInt32 a)
149 {
150     vst1q_s32(m, a.simdInternal_);
151 }
152
153 static inline SimdFInt32 gmx_simdcall setZeroFI()
154 {
155     return { vdupq_n_s32(0) };
156 }
157
158 template<int index>
159 gmx_simdcall static inline std::int32_t extract(SimdFInt32 a)
160 {
161     return vgetq_lane_s32(a.simdInternal_, index);
162 }
163
164 static inline SimdFloat gmx_simdcall operator&(SimdFloat a, SimdFloat b)
165 {
166     return { vreinterpretq_f32_s32(vandq_s32(vreinterpretq_s32_f32(a.simdInternal_),
167                                              vreinterpretq_s32_f32(b.simdInternal_))) };
168 }
169
170 static inline SimdFloat gmx_simdcall andNot(SimdFloat a, SimdFloat b)
171 {
172     return { vreinterpretq_f32_s32(vbicq_s32(vreinterpretq_s32_f32(b.simdInternal_),
173                                              vreinterpretq_s32_f32(a.simdInternal_))) };
174 }
175
176 static inline SimdFloat gmx_simdcall operator|(SimdFloat a, SimdFloat b)
177 {
178     return { vreinterpretq_f32_s32(vorrq_s32(vreinterpretq_s32_f32(a.simdInternal_),
179                                              vreinterpretq_s32_f32(b.simdInternal_))) };
180 }
181
182 static inline SimdFloat gmx_simdcall operator^(SimdFloat a, SimdFloat b)
183 {
184     return { vreinterpretq_f32_s32(veorq_s32(vreinterpretq_s32_f32(a.simdInternal_),
185                                              vreinterpretq_s32_f32(b.simdInternal_))) };
186 }
187
188 static inline SimdFloat gmx_simdcall operator+(SimdFloat a, SimdFloat b)
189 {
190     return { vaddq_f32(a.simdInternal_, b.simdInternal_) };
191 }
192
193 static inline SimdFloat gmx_simdcall operator-(SimdFloat a, SimdFloat b)
194 {
195     return { vsubq_f32(a.simdInternal_, b.simdInternal_) };
196 }
197
198 static inline SimdFloat gmx_simdcall operator-(SimdFloat x)
199 {
200     return { vnegq_f32(x.simdInternal_) };
201 }
202
203 static inline SimdFloat gmx_simdcall operator*(SimdFloat a, SimdFloat b)
204 {
205     return { vmulq_f32(a.simdInternal_, b.simdInternal_) };
206 }
207
208 static inline SimdFloat gmx_simdcall rsqrt(SimdFloat x)
209 {
210     return { vrsqrteq_f32(x.simdInternal_) };
211 }
212
213 // The SIMD implementation seems to overflow when we square lu for
214 // values close to FLOAT_MAX, so we fall back on the version in
215 // simd_math.h, which is probably slightly slower.
216 #if GMX_SIMD_HAVE_NATIVE_RSQRT_ITER_FLOAT
217 static inline SimdFloat gmx_simdcall rsqrtIter(SimdFloat lu, SimdFloat x)
218 {
219     return { vmulq_f32(lu.simdInternal_,
220                        vrsqrtsq_f32(vmulq_f32(lu.simdInternal_, lu.simdInternal_), x.simdInternal_)) };
221 }
222 #endif
223
224 static inline SimdFloat gmx_simdcall rcp(SimdFloat x)
225 {
226     return { vrecpeq_f32(x.simdInternal_) };
227 }
228
229 static inline SimdFloat gmx_simdcall rcpIter(SimdFloat lu, SimdFloat x)
230 {
231     return { vmulq_f32(lu.simdInternal_, vrecpsq_f32(lu.simdInternal_, x.simdInternal_)) };
232 }
233
234 static inline SimdFloat gmx_simdcall maskAdd(SimdFloat a, SimdFloat b, SimdFBool m)
235 {
236     b.simdInternal_ =
237             vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(b.simdInternal_), m.simdInternal_));
238
239     return { vaddq_f32(a.simdInternal_, b.simdInternal_) };
240 }
241
242 static inline SimdFloat gmx_simdcall maskzMul(SimdFloat a, SimdFloat b, SimdFBool m)
243 {
244     SimdFloat tmp = a * b;
245
246     return { vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(tmp.simdInternal_), m.simdInternal_)) };
247 }
248
249 static inline SimdFloat gmx_simdcall maskzFma(SimdFloat a, SimdFloat b, SimdFloat c, SimdFBool m)
250 {
251 #ifdef __ARM_FEATURE_FMA
252     float32x4_t tmp = vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
253 #else
254     float32x4_t tmp = vmlaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
255 #endif
256
257     return { vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(tmp), m.simdInternal_)) };
258 }
259
260 static inline SimdFloat gmx_simdcall maskzRsqrt(SimdFloat x, SimdFBool m)
261 {
262     // The result will always be correct since we mask the result with m, but
263     // for debug builds we also want to make sure not to generate FP exceptions
264 #ifndef NDEBUG
265     x.simdInternal_ = vbslq_f32(m.simdInternal_, x.simdInternal_, vdupq_n_f32(1.0F));
266 #endif
267     return { vreinterpretq_f32_u32(
268             vandq_u32(vreinterpretq_u32_f32(vrsqrteq_f32(x.simdInternal_)), m.simdInternal_)) };
269 }
270
271 static inline SimdFloat gmx_simdcall maskzRcp(SimdFloat x, SimdFBool m)
272 {
273     // The result will always be correct since we mask the result with m, but
274     // for debug builds we also want to make sure not to generate FP exceptions
275 #ifndef NDEBUG
276     x.simdInternal_ = vbslq_f32(m.simdInternal_, x.simdInternal_, vdupq_n_f32(1.0F));
277 #endif
278     return { vreinterpretq_f32_u32(
279             vandq_u32(vreinterpretq_u32_f32(vrecpeq_f32(x.simdInternal_)), m.simdInternal_)) };
280 }
281
282 static inline SimdFloat gmx_simdcall abs(SimdFloat x)
283 {
284     return { vabsq_f32(x.simdInternal_) };
285 }
286
287 static inline SimdFloat gmx_simdcall max(SimdFloat a, SimdFloat b)
288 {
289     return { vmaxq_f32(a.simdInternal_, b.simdInternal_) };
290 }
291
292 static inline SimdFloat gmx_simdcall min(SimdFloat a, SimdFloat b)
293 {
294     return { vminq_f32(a.simdInternal_, b.simdInternal_) };
295 }
296
297 // Round and trunc operations are defined at the end of this file, since they
298 // need to use float-to-integer and integer-to-float conversions.
299
300 template<MathOptimization opt = MathOptimization::Safe>
301 static inline SimdFloat gmx_simdcall frexp(SimdFloat value, SimdFInt32* exponent)
302 {
303     const int32x4_t exponentMask = vdupq_n_s32(0x7F800000);
304     const int32x4_t mantissaMask = vdupq_n_s32(0x807FFFFF);
305     const int32x4_t exponentBias = vdupq_n_s32(126); // add 1 to make our definition identical to frexp()
306     const float32x4_t half = vdupq_n_f32(0.5F);
307     int32x4_t         iExponent;
308
309     iExponent = vandq_s32(vreinterpretq_s32_f32(value.simdInternal_), exponentMask);
310     iExponent = vsubq_s32(vshrq_n_s32(iExponent, 23), exponentBias);
311
312     float32x4_t result = vreinterpretq_f32_s32(
313             vorrq_s32(vandq_s32(vreinterpretq_s32_f32(value.simdInternal_), mantissaMask),
314                       vreinterpretq_s32_f32(half)));
315
316     if (opt == MathOptimization::Safe)
317     {
318         uint32x4_t valueIsZero = vceqq_f32(value.simdInternal_, vdupq_n_f32(0.0F));
319         iExponent              = vbicq_s32(iExponent, vreinterpretq_s32_u32(valueIsZero));
320         result                 = vbslq_f32(valueIsZero, value.simdInternal_, result);
321     }
322
323     exponent->simdInternal_ = iExponent;
324     return { result };
325 }
326
327 template<MathOptimization opt = MathOptimization::Safe>
328 static inline SimdFloat gmx_simdcall ldexp(SimdFloat value, SimdFInt32 exponent)
329 {
330     const int32x4_t exponentBias = vdupq_n_s32(127);
331     int32x4_t       iExponent    = vaddq_s32(exponent.simdInternal_, exponentBias);
332
333     if (opt == MathOptimization::Safe)
334     {
335         // Make sure biased argument is not negative
336         iExponent = vmaxq_s32(iExponent, vdupq_n_s32(0));
337     }
338
339     iExponent = vshlq_n_s32(iExponent, 23);
340
341     return { vmulq_f32(value.simdInternal_, vreinterpretq_f32_s32(iExponent)) };
342 }
343
344 static inline SimdFBool gmx_simdcall operator==(SimdFloat a, SimdFloat b)
345 {
346     return { vceqq_f32(a.simdInternal_, b.simdInternal_) };
347 }
348
349 static inline SimdFBool gmx_simdcall operator!=(SimdFloat a, SimdFloat b)
350 {
351     return { vmvnq_u32(vceqq_f32(a.simdInternal_, b.simdInternal_)) };
352 }
353
354 static inline SimdFBool gmx_simdcall operator<(SimdFloat a, SimdFloat b)
355 {
356     return { vcltq_f32(a.simdInternal_, b.simdInternal_) };
357 }
358
359 static inline SimdFBool gmx_simdcall operator<=(SimdFloat a, SimdFloat b)
360 {
361     return { vcleq_f32(a.simdInternal_, b.simdInternal_) };
362 }
363
364 static inline SimdFBool gmx_simdcall testBits(SimdFloat a)
365 {
366     uint32x4_t tmp = vreinterpretq_u32_f32(a.simdInternal_);
367
368     return { vtstq_u32(tmp, tmp) };
369 }
370
371 static inline SimdFBool gmx_simdcall operator&&(SimdFBool a, SimdFBool b)
372 {
373
374     return { vandq_u32(a.simdInternal_, b.simdInternal_) };
375 }
376
377 static inline SimdFBool gmx_simdcall operator||(SimdFBool a, SimdFBool b)
378 {
379     return { vorrq_u32(a.simdInternal_, b.simdInternal_) };
380 }
381
382 static inline SimdFloat gmx_simdcall selectByMask(SimdFloat a, SimdFBool m)
383 {
384     return { vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a.simdInternal_), m.simdInternal_)) };
385 }
386
387 static inline SimdFloat gmx_simdcall selectByNotMask(SimdFloat a, SimdFBool m)
388 {
389     return { vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a.simdInternal_), m.simdInternal_)) };
390 }
391
392 static inline SimdFloat gmx_simdcall blend(SimdFloat a, SimdFloat b, SimdFBool sel)
393 {
394     return { vbslq_f32(sel.simdInternal_, b.simdInternal_, a.simdInternal_) };
395 }
396
397 static inline SimdFInt32 gmx_simdcall operator&(SimdFInt32 a, SimdFInt32 b)
398 {
399     return { vandq_s32(a.simdInternal_, b.simdInternal_) };
400 }
401
402 static inline SimdFInt32 gmx_simdcall andNot(SimdFInt32 a, SimdFInt32 b)
403 {
404     return { vbicq_s32(b.simdInternal_, a.simdInternal_) };
405 }
406
407 static inline SimdFInt32 gmx_simdcall operator|(SimdFInt32 a, SimdFInt32 b)
408 {
409     return { vorrq_s32(a.simdInternal_, b.simdInternal_) };
410 }
411
412 static inline SimdFInt32 gmx_simdcall operator^(SimdFInt32 a, SimdFInt32 b)
413 {
414     return { veorq_s32(a.simdInternal_, b.simdInternal_) };
415 }
416
417 static inline SimdFInt32 gmx_simdcall operator+(SimdFInt32 a, SimdFInt32 b)
418 {
419     return { vaddq_s32(a.simdInternal_, b.simdInternal_) };
420 }
421
422 static inline SimdFInt32 gmx_simdcall operator-(SimdFInt32 a, SimdFInt32 b)
423 {
424     return { vsubq_s32(a.simdInternal_, b.simdInternal_) };
425 }
426
427 static inline SimdFInt32 gmx_simdcall operator*(SimdFInt32 a, SimdFInt32 b)
428 {
429     return { vmulq_s32(a.simdInternal_, b.simdInternal_) };
430 }
431
432 static inline SimdFIBool gmx_simdcall operator==(SimdFInt32 a, SimdFInt32 b)
433 {
434     return { vceqq_s32(a.simdInternal_, b.simdInternal_) };
435 }
436
437 static inline SimdFIBool gmx_simdcall testBits(SimdFInt32 a)
438 {
439     return { vtstq_s32(a.simdInternal_, a.simdInternal_) };
440 }
441
442 static inline SimdFIBool gmx_simdcall operator<(SimdFInt32 a, SimdFInt32 b)
443 {
444     return { vcltq_s32(a.simdInternal_, b.simdInternal_) };
445 }
446
447 static inline SimdFIBool gmx_simdcall operator&&(SimdFIBool a, SimdFIBool b)
448 {
449     return { vandq_u32(a.simdInternal_, b.simdInternal_) };
450 }
451
452 static inline SimdFIBool gmx_simdcall operator||(SimdFIBool a, SimdFIBool b)
453 {
454     return { vorrq_u32(a.simdInternal_, b.simdInternal_) };
455 }
456
457 static inline SimdFInt32 gmx_simdcall selectByMask(SimdFInt32 a, SimdFIBool m)
458 {
459     return { vandq_s32(a.simdInternal_, vreinterpretq_s32_u32(m.simdInternal_)) };
460 }
461
462 static inline SimdFInt32 gmx_simdcall selectByNotMask(SimdFInt32 a, SimdFIBool m)
463 {
464     return { vbicq_s32(a.simdInternal_, vreinterpretq_s32_u32(m.simdInternal_)) };
465 }
466
467 static inline SimdFInt32 gmx_simdcall blend(SimdFInt32 a, SimdFInt32 b, SimdFIBool sel)
468 {
469     return { vbslq_s32(sel.simdInternal_, b.simdInternal_, a.simdInternal_) };
470 }
471
472 static inline SimdFInt32 gmx_simdcall cvttR2I(SimdFloat a)
473 {
474     return { vcvtq_s32_f32(a.simdInternal_) };
475 }
476
477 static inline SimdFloat gmx_simdcall cvtI2R(SimdFInt32 a)
478 {
479     return { vcvtq_f32_s32(a.simdInternal_) };
480 }
481
482 static inline SimdFIBool gmx_simdcall cvtB2IB(SimdFBool a)
483 {
484     return { a.simdInternal_ };
485 }
486
487 static inline SimdFBool gmx_simdcall cvtIB2B(SimdFIBool a)
488 {
489     return { a.simdInternal_ };
490 }
491
492 static inline SimdFloat gmx_simdcall fma(SimdFloat a, SimdFloat b, SimdFloat c)
493 {
494     return { vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_) };
495 }
496
497 static inline SimdFloat gmx_simdcall fms(SimdFloat a, SimdFloat b, SimdFloat c)
498 {
499     return { vnegq_f32(vfmsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)) };
500 }
501
502 static inline SimdFloat gmx_simdcall fnma(SimdFloat a, SimdFloat b, SimdFloat c)
503 {
504     return { vfmsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_) };
505 }
506
507 static inline SimdFloat gmx_simdcall fnms(SimdFloat a, SimdFloat b, SimdFloat c)
508 {
509     return { vnegq_f32(vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)) };
510 }
511
512 static inline SimdFloat gmx_simdcall round(SimdFloat x)
513 {
514     return { vrndnq_f32(x.simdInternal_) };
515 }
516
517 static inline SimdFloat gmx_simdcall trunc(SimdFloat x)
518 {
519     return { vrndq_f32(x.simdInternal_) };
520 }
521
522 static inline SimdFInt32 gmx_simdcall cvtR2I(SimdFloat a)
523 {
524     return { vcvtnq_s32_f32(a.simdInternal_) };
525 }
526
527 static inline bool gmx_simdcall anyTrue(SimdFBool a)
528 {
529     return (vmaxvq_u32(a.simdInternal_) != 0);
530 }
531
532 static inline bool gmx_simdcall anyTrue(SimdFIBool a)
533 {
534     return (vmaxvq_u32(a.simdInternal_) != 0);
535 }
536
537 static inline float gmx_simdcall reduce(SimdFloat a)
538 {
539     float32x4_t b = a.simdInternal_;
540     b             = vpaddq_f32(b, b);
541     b             = vpaddq_f32(b, b);
542     return vgetq_lane_f32(b, 0);
543 }
544
545 } // namespace gmx
546
547 #endif // GMX_SIMD_IMPL_ARM_NEON_ASIMD_SIMD_FLOAT_H