/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017,2019,2020, by the GROMACS development team.
 * Copyright (c) 2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
#ifndef GMX_SIMD_IMPL_ARM_NEON_ASIMD_SIMD_FLOAT_H
#define GMX_SIMD_IMPL_ARM_NEON_ASIMD_SIMD_FLOAT_H

#include <arm_neon.h>

#include <cassert>
#include <cstddef>
#include <cstdint>

#include "gromacs/math/utilities.h"
//! \brief AArch64 NEON register wrapper holding four packed single-precision floats.
// NOTE(review): class shell reconstructed from the surviving member lines — confirm
// against the pristine source that no additional members existed.
class SimdFloat
{
public:
    SimdFloat() {}

    //! Broadcast-construct: replicate one scalar into all four lanes.
    SimdFloat(float f) : simdInternal_(vdupq_n_f32(f)) {}

    // Internal utility constructor to simplify return statements
    SimdFloat(float32x4_t simd) : simdInternal_(simd) {}

    // Raw NEON register; public so the free-function operations below can access it.
    float32x4_t simdInternal_;
};
//! \brief AArch64 NEON register wrapper holding four packed 32-bit signed integers.
class SimdFInt32
{
public:
    SimdFInt32() {}

    //! Broadcast-construct: replicate one scalar into all four lanes.
    SimdFInt32(std::int32_t i) : simdInternal_(vdupq_n_s32(i)) {}

    // Internal utility constructor to simplify return statements
    SimdFInt32(int32x4_t simd) : simdInternal_(simd) {}

    // Raw NEON register; public so the free-function operations below can access it.
    int32x4_t simdInternal_;
};
//! \brief Boolean mask for SimdFloat: each 32-bit lane is all-ones (true) or all-zeros (false).
class SimdFBool
{
public:
    SimdFBool() {}

    //! Broadcast-construct: every lane becomes 0xFFFFFFFF (true) or 0 (false).
    SimdFBool(bool b) : simdInternal_(vdupq_n_u32(b ? 0xFFFFFFFF : 0)) {}

    // Internal utility constructor to simplify return statements
    SimdFBool(uint32x4_t simd) : simdInternal_(simd) {}

    // Raw NEON register; public so the free-function operations below can access it.
    uint32x4_t simdInternal_;
};
//! \brief Boolean mask for SimdFInt32: each 32-bit lane is all-ones (true) or all-zeros (false).
class SimdFIBool
{
public:
    SimdFIBool() {}

    //! Broadcast-construct: every lane becomes 0xFFFFFFFF (true) or 0 (false).
    SimdFIBool(bool b) : simdInternal_(vdupq_n_u32(b ? 0xFFFFFFFF : 0)) {}

    // Internal utility constructor to simplify return statements
    SimdFIBool(uint32x4_t simd) : simdInternal_(simd) {}

    // Raw NEON register; public so the free-function operations below can access it.
    uint32x4_t simdInternal_;
};
//! \brief Aligned load of four floats; \p m must be 16-byte aligned (asserted in debug builds).
static inline SimdFloat gmx_simdcall simdLoad(const float* m, SimdFloatTag = {})
{
    assert(std::size_t(m) % 16 == 0);
    return { vld1q_f32(m) };
}

//! \brief Aligned store of four floats; \p m must be 16-byte aligned (asserted in debug builds).
static inline void gmx_simdcall store(float* m, SimdFloat a)
{
    assert(std::size_t(m) % 16 == 0);
    vst1q_f32(m, a.simdInternal_);
}

//! \brief Unaligned load of four floats (vld1q_f32 accepts any alignment on AArch64).
static inline SimdFloat gmx_simdcall simdLoadU(const float* m, SimdFloatTag = {})
{
    return { vld1q_f32(m) };
}

//! \brief Unaligned store of four floats.
static inline void gmx_simdcall storeU(float* m, SimdFloat a)
{
    vst1q_f32(m, a.simdInternal_);
}

//! \brief Return a SimdFloat with all four lanes set to 0.0f.
static inline SimdFloat gmx_simdcall setZeroF()
{
    return { vdupq_n_f32(0.0F) };
}
//! \brief Aligned load of four 32-bit integers; \p m must be 16-byte aligned.
static inline SimdFInt32 gmx_simdcall simdLoad(const std::int32_t* m, SimdFInt32Tag)
{
    assert(std::size_t(m) % 16 == 0);
    return { vld1q_s32(m) };
}

//! \brief Aligned store of four 32-bit integers; \p m must be 16-byte aligned.
static inline void gmx_simdcall store(std::int32_t* m, SimdFInt32 a)
{
    assert(std::size_t(m) % 16 == 0);
    vst1q_s32(m, a.simdInternal_);
}

//! \brief Unaligned load of four 32-bit integers.
static inline SimdFInt32 gmx_simdcall simdLoadU(const std::int32_t* m, SimdFInt32Tag)
{
    return { vld1q_s32(m) };
}

//! \brief Unaligned store of four 32-bit integers.
static inline void gmx_simdcall storeU(std::int32_t* m, SimdFInt32 a)
{
    vst1q_s32(m, a.simdInternal_);
}

//! \brief Return a SimdFInt32 with all four lanes set to 0.
static inline SimdFInt32 gmx_simdcall setZeroFI()
{
    return { vdupq_n_s32(0) };
}

//! \brief Extract lane \p index (compile-time constant, 0-3) as a scalar integer.
template<int index>
gmx_simdcall static inline std::int32_t extract(SimdFInt32 a)
{
    return vgetq_lane_s32(a.simdInternal_, index);
}
//! \brief Bitwise AND of two float registers, operating on the raw bit patterns.
static inline SimdFloat gmx_simdcall operator&(SimdFloat a, SimdFloat b)
{
    return { vreinterpretq_f32_s32(vandq_s32(vreinterpretq_s32_f32(a.simdInternal_),
                                             vreinterpretq_s32_f32(b.simdInternal_))) };
}

//! \brief Bitwise (~a) & b. vbicq_s32(x, y) computes x & ~y, hence the swapped arguments.
static inline SimdFloat gmx_simdcall andNot(SimdFloat a, SimdFloat b)
{
    return { vreinterpretq_f32_s32(vbicq_s32(vreinterpretq_s32_f32(b.simdInternal_),
                                             vreinterpretq_s32_f32(a.simdInternal_))) };
}

//! \brief Bitwise OR of two float registers, operating on the raw bit patterns.
static inline SimdFloat gmx_simdcall operator|(SimdFloat a, SimdFloat b)
{
    return { vreinterpretq_f32_s32(vorrq_s32(vreinterpretq_s32_f32(a.simdInternal_),
                                             vreinterpretq_s32_f32(b.simdInternal_))) };
}

//! \brief Bitwise XOR of two float registers, operating on the raw bit patterns.
static inline SimdFloat gmx_simdcall operator^(SimdFloat a, SimdFloat b)
{
    return { vreinterpretq_f32_s32(veorq_s32(vreinterpretq_s32_f32(a.simdInternal_),
                                             vreinterpretq_s32_f32(b.simdInternal_))) };
}
//! \brief Lane-wise addition of two SimdFloat values.
static inline SimdFloat gmx_simdcall operator+(SimdFloat a, SimdFloat b)
{
    return { vaddq_f32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise subtraction, a - b.
static inline SimdFloat gmx_simdcall operator-(SimdFloat a, SimdFloat b)
{
    return { vsubq_f32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise negation.
static inline SimdFloat gmx_simdcall operator-(SimdFloat x)
{
    return { vnegq_f32(x.simdInternal_) };
}

//! \brief Lane-wise multiplication of two SimdFloat values.
static inline SimdFloat gmx_simdcall operator*(SimdFloat a, SimdFloat b)
{
    return { vmulq_f32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Low-accuracy hardware estimate of 1/sqrt(x); refine with rsqrtIter().
static inline SimdFloat gmx_simdcall rsqrt(SimdFloat x)
{
    return { vrsqrteq_f32(x.simdInternal_) };
}
// The SIMD implementation seems to overflow when we square lu for
// values close to FLOAT_MAX, so we fall back on the version in
// simd_math.h, which is probably slightly slower.
#if GMX_SIMD_HAVE_NATIVE_RSQRT_ITER_FLOAT
//! \brief One Newton-Raphson iteration refining \p lu as an estimate of 1/sqrt(x).
//! vrsqrtsq_f32(a, b) computes (3 - a*b)/2, so this returns lu * (3 - lu*lu*x) / 2.
static inline SimdFloat gmx_simdcall rsqrtIter(SimdFloat lu, SimdFloat x)
{
    return { vmulq_f32(lu.simdInternal_,
                       vrsqrtsq_f32(vmulq_f32(lu.simdInternal_, lu.simdInternal_), x.simdInternal_)) };
}
#endif

//! \brief Low-accuracy hardware estimate of 1/x; refine with rcpIter().
static inline SimdFloat gmx_simdcall rcp(SimdFloat x)
{
    return { vrecpeq_f32(x.simdInternal_) };
}

//! \brief One Newton-Raphson iteration refining \p lu as an estimate of 1/x.
//! vrecpsq_f32(a, b) computes 2 - a*b, so this returns lu * (2 - lu*x).
static inline SimdFloat gmx_simdcall rcpIter(SimdFloat lu, SimdFloat x)
{
    return { vmulq_f32(lu.simdInternal_, vrecpsq_f32(lu.simdInternal_, x.simdInternal_)) };
}
//! \brief Masked add: a + b where \p m is true, plain a where it is false.
static inline SimdFloat gmx_simdcall maskAdd(SimdFloat a, SimdFloat b, SimdFBool m)
{
    // Zero the addend in false lanes before the add (0.0f has an all-zero bit pattern).
    b.simdInternal_ = vreinterpretq_f32_u32(
            vandq_u32(vreinterpretq_u32_f32(b.simdInternal_), m.simdInternal_));

    return { vaddq_f32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Masked multiply: a * b where \p m is true, 0.0f where it is false.
static inline SimdFloat gmx_simdcall maskzMul(SimdFloat a, SimdFloat b, SimdFBool m)
{
    SimdFloat tmp = a * b;

    return { vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(tmp.simdInternal_), m.simdInternal_)) };
}

//! \brief Masked fused multiply-add: a * b + c where \p m is true, 0.0f where it is false.
static inline SimdFloat gmx_simdcall maskzFma(SimdFloat a, SimdFloat b, SimdFloat c, SimdFBool m)
{
#ifdef __ARM_FEATURE_FMA
    float32x4_t tmp = vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#else
    // Without hardware FMA fall back to multiply-accumulate (separate rounding).
    float32x4_t tmp = vmlaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#endif

    return { vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(tmp), m.simdInternal_)) };
}
//! \brief Masked 1/sqrt(x) estimate: false lanes return 0.0f.
static inline SimdFloat gmx_simdcall maskzRsqrt(SimdFloat x, SimdFBool m)
{
    // The result will always be correct since we mask the result with m, but
    // for debug builds we also want to make sure not to generate FP exceptions
#ifndef NDEBUG
    // Replace false-lane inputs with 1.0f so the estimate never sees 0 or negatives.
    x.simdInternal_ = vbslq_f32(m.simdInternal_, x.simdInternal_, vdupq_n_f32(1.0F));
#endif
    return { vreinterpretq_f32_u32(
            vandq_u32(vreinterpretq_u32_f32(vrsqrteq_f32(x.simdInternal_)), m.simdInternal_)) };
}

//! \brief Masked 1/x estimate: false lanes return 0.0f.
static inline SimdFloat gmx_simdcall maskzRcp(SimdFloat x, SimdFBool m)
{
    // The result will always be correct since we mask the result with m, but
    // for debug builds we also want to make sure not to generate FP exceptions
#ifndef NDEBUG
    // Replace false-lane inputs with 1.0f so the estimate never divides by zero.
    x.simdInternal_ = vbslq_f32(m.simdInternal_, x.simdInternal_, vdupq_n_f32(1.0F));
#endif
    return { vreinterpretq_f32_u32(
            vandq_u32(vreinterpretq_u32_f32(vrecpeq_f32(x.simdInternal_)), m.simdInternal_)) };
}
//! \brief Lane-wise absolute value.
static inline SimdFloat gmx_simdcall abs(SimdFloat x)
{
    return { vabsq_f32(x.simdInternal_) };
}

//! \brief Lane-wise maximum of two SimdFloat values.
static inline SimdFloat gmx_simdcall max(SimdFloat a, SimdFloat b)
{
    return { vmaxq_f32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise minimum of two SimdFloat values.
static inline SimdFloat gmx_simdcall min(SimdFloat a, SimdFloat b)
{
    return { vminq_f32(a.simdInternal_, b.simdInternal_) };
}

// Round and trunc operations are defined at the end of this file, since they
// need to use float-to-integer and integer-to-float conversions.
/*! \brief SIMD-wide std::frexp: split \p value into a mantissa in [0.5, 1) and an
 *         integer exponent so that value = mantissa * 2^exponent.
 *
 * With MathOptimization::Safe, zero inputs produce a zero mantissa and zero
 * exponent instead of the garbage the bit manipulation would otherwise yield.
 */
template<MathOptimization opt = MathOptimization::Safe>
static inline SimdFloat gmx_simdcall frexp(SimdFloat value, SimdFInt32* exponent)
{
    const int32x4_t exponentMask = vdupq_n_s32(0x7F800000);
    const int32x4_t mantissaMask = vdupq_n_s32(0x807FFFFF);
    const int32x4_t exponentBias = vdupq_n_s32(126); // add 1 to make our definition identical to frexp()
    const float32x4_t half       = vdupq_n_f32(0.5F);
    int32x4_t         iExponent;

    // Extract the biased IEEE-754 exponent field and un-bias it (127 - 1 = 126).
    iExponent = vandq_s32(vreinterpretq_s32_f32(value.simdInternal_), exponentMask);
    iExponent = vsubq_s32(vshrq_n_s32(iExponent, 23), exponentBias);

    // Keep sign + mantissa bits and OR in the exponent of 0.5 to land in [0.5, 1).
    float32x4_t result = vreinterpretq_f32_s32(
            vorrq_s32(vandq_s32(vreinterpretq_s32_f32(value.simdInternal_), mantissaMask),
                      vreinterpretq_s32_f32(half)));

    if (opt == MathOptimization::Safe)
    {
        uint32x4_t valueIsZero = vceqq_f32(value.simdInternal_, vdupq_n_f32(0.0F));
        iExponent              = vbicq_s32(iExponent, vreinterpretq_s32_u32(valueIsZero));
        result                 = vbslq_f32(valueIsZero, value.simdInternal_, result);
    }

    exponent->simdInternal_ = iExponent;

    return { result };
}
/*! \brief SIMD-wide std::ldexp: return value * 2^exponent.
 *
 * Builds the power-of-two multiplier directly in the IEEE-754 exponent field.
 * With MathOptimization::Safe the biased exponent is clamped at zero, so
 * results that would underflow flush to zero instead of wrapping.
 */
template<MathOptimization opt = MathOptimization::Safe>
static inline SimdFloat gmx_simdcall ldexp(SimdFloat value, SimdFInt32 exponent)
{
    const int32x4_t exponentBias = vdupq_n_s32(127);
    int32x4_t       iExponent    = vaddq_s32(exponent.simdInternal_, exponentBias);

    if (opt == MathOptimization::Safe)
    {
        // Make sure biased argument is not negative
        iExponent = vmaxq_s32(iExponent, vdupq_n_s32(0));
    }

    iExponent = vshlq_n_s32(iExponent, 23);

    return { vmulq_f32(value.simdInternal_, vreinterpretq_f32_s32(iExponent)) };
}
//! \brief Lane-wise equality comparison; true lanes become all-ones masks.
static inline SimdFBool gmx_simdcall operator==(SimdFloat a, SimdFloat b)
{
    return { vceqq_f32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise inequality: NEON has no "not equal", so invert the equality mask.
static inline SimdFBool gmx_simdcall operator!=(SimdFloat a, SimdFloat b)
{
    return { vmvnq_u32(vceqq_f32(a.simdInternal_, b.simdInternal_)) };
}

//! \brief Lane-wise less-than comparison.
static inline SimdFBool gmx_simdcall operator<(SimdFloat a, SimdFloat b)
{
    return { vcltq_f32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise less-than-or-equal comparison.
static inline SimdFBool gmx_simdcall operator<=(SimdFloat a, SimdFloat b)
{
    return { vcleq_f32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane is true when any bit of the float's raw pattern is set (i.e. not +0.0).
static inline SimdFBool gmx_simdcall testBits(SimdFloat a)
{
    uint32x4_t tmp = vreinterpretq_u32_f32(a.simdInternal_);

    return { vtstq_u32(tmp, tmp) };
}
//! \brief Lane-wise logical AND of two float masks (bitwise AND of all-ones/zero lanes).
static inline SimdFBool gmx_simdcall operator&&(SimdFBool a, SimdFBool b)
{
    return { vandq_u32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise logical OR of two float masks.
static inline SimdFBool gmx_simdcall operator||(SimdFBool a, SimdFBool b)
{
    return { vorrq_u32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Select a where \p m is true, 0.0f elsewhere.
static inline SimdFloat gmx_simdcall selectByMask(SimdFloat a, SimdFBool m)
{
    return { vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a.simdInternal_), m.simdInternal_)) };
}

//! \brief Select a where \p m is false, 0.0f elsewhere (vbicq computes a & ~m).
static inline SimdFloat gmx_simdcall selectByNotMask(SimdFloat a, SimdFBool m)
{
    return { vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a.simdInternal_), m.simdInternal_)) };
}

//! \brief Blend: b where \p sel is true, a where it is false.
static inline SimdFloat gmx_simdcall blend(SimdFloat a, SimdFloat b, SimdFBool sel)
{
    return { vbslq_f32(sel.simdInternal_, b.simdInternal_, a.simdInternal_) };
}
//! \brief Lane-wise bitwise AND of two integer registers.
static inline SimdFInt32 gmx_simdcall operator&(SimdFInt32 a, SimdFInt32 b)
{
    return { vandq_s32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise (~a) & b; vbicq_s32(x, y) computes x & ~y, hence the swapped arguments.
static inline SimdFInt32 gmx_simdcall andNot(SimdFInt32 a, SimdFInt32 b)
{
    return { vbicq_s32(b.simdInternal_, a.simdInternal_) };
}

//! \brief Lane-wise bitwise OR of two integer registers.
static inline SimdFInt32 gmx_simdcall operator|(SimdFInt32 a, SimdFInt32 b)
{
    return { vorrq_s32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise bitwise XOR of two integer registers.
static inline SimdFInt32 gmx_simdcall operator^(SimdFInt32 a, SimdFInt32 b)
{
    return { veorq_s32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise integer addition.
static inline SimdFInt32 gmx_simdcall operator+(SimdFInt32 a, SimdFInt32 b)
{
    return { vaddq_s32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise integer subtraction, a - b.
static inline SimdFInt32 gmx_simdcall operator-(SimdFInt32 a, SimdFInt32 b)
{
    return { vsubq_s32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise integer multiplication (low 32 bits of the product).
static inline SimdFInt32 gmx_simdcall operator*(SimdFInt32 a, SimdFInt32 b)
{
    return { vmulq_s32(a.simdInternal_, b.simdInternal_) };
}
//! \brief Lane-wise integer equality comparison.
static inline SimdFIBool gmx_simdcall operator==(SimdFInt32 a, SimdFInt32 b)
{
    return { vceqq_s32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane is true when the lane has any bit set (i.e. is nonzero).
static inline SimdFIBool gmx_simdcall testBits(SimdFInt32 a)
{
    return { vtstq_s32(a.simdInternal_, a.simdInternal_) };
}

//! \brief Lane-wise signed integer less-than comparison.
static inline SimdFIBool gmx_simdcall operator<(SimdFInt32 a, SimdFInt32 b)
{
    return { vcltq_s32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise logical AND of two integer masks.
static inline SimdFIBool gmx_simdcall operator&&(SimdFIBool a, SimdFIBool b)
{
    return { vandq_u32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Lane-wise logical OR of two integer masks.
static inline SimdFIBool gmx_simdcall operator||(SimdFIBool a, SimdFIBool b)
{
    return { vorrq_u32(a.simdInternal_, b.simdInternal_) };
}

//! \brief Select a where \p m is true, 0 elsewhere.
static inline SimdFInt32 gmx_simdcall selectByMask(SimdFInt32 a, SimdFIBool m)
{
    return { vandq_s32(a.simdInternal_, vreinterpretq_s32_u32(m.simdInternal_)) };
}

//! \brief Select a where \p m is false, 0 elsewhere.
static inline SimdFInt32 gmx_simdcall selectByNotMask(SimdFInt32 a, SimdFIBool m)
{
    return { vbicq_s32(a.simdInternal_, vreinterpretq_s32_u32(m.simdInternal_)) };
}

//! \brief Blend: b where \p sel is true, a where it is false.
static inline SimdFInt32 gmx_simdcall blend(SimdFInt32 a, SimdFInt32 b, SimdFIBool sel)
{
    return { vbslq_s32(sel.simdInternal_, b.simdInternal_, a.simdInternal_) };
}
//! \brief Convert float to integer, truncating toward zero (vcvtq rounds toward zero).
static inline SimdFInt32 gmx_simdcall cvttR2I(SimdFloat a)
{
    return { vcvtq_s32_f32(a.simdInternal_) };
}

//! \brief Convert integer lanes to float lanes.
static inline SimdFloat gmx_simdcall cvtI2R(SimdFInt32 a)
{
    return { vcvtq_f32_s32(a.simdInternal_) };
}

//! \brief Float-mask to int-mask conversion: both are uint32x4_t, so it is a no-op.
static inline SimdFIBool gmx_simdcall cvtB2IB(SimdFBool a)
{
    return { a.simdInternal_ };
}

//! \brief Int-mask to float-mask conversion: both are uint32x4_t, so it is a no-op.
static inline SimdFBool gmx_simdcall cvtIB2B(SimdFIBool a)
{
    return { a.simdInternal_ };
}
//! \brief Fused multiply-add: a * b + c (vfmaq_f32(acc, x, y) computes acc + x*y).
static inline SimdFloat gmx_simdcall fma(SimdFloat a, SimdFloat b, SimdFloat c)
{
    return { vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_) };
}

//! \brief Fused multiply-subtract: a * b - c, via -(c - a*b) since vfmsq computes c - a*b.
static inline SimdFloat gmx_simdcall fms(SimdFloat a, SimdFloat b, SimdFloat c)
{
    return { vnegq_f32(vfmsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)) };
}

//! \brief Fused negated multiply-add: c - a * b (directly vfmsq_f32).
static inline SimdFloat gmx_simdcall fnma(SimdFloat a, SimdFloat b, SimdFloat c)
{
    return { vfmsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_) };
}

//! \brief Fused negated multiply-subtract: -(a * b) - c, via -(c + a*b).
static inline SimdFloat gmx_simdcall fnms(SimdFloat a, SimdFloat b, SimdFloat c)
{
    return { vnegq_f32(vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)) };
}
//! \brief Round each lane to the nearest integral value, ties to even (vrndnq).
static inline SimdFloat gmx_simdcall round(SimdFloat x)
{
    return { vrndnq_f32(x.simdInternal_) };
}

//! \brief Truncate each lane toward zero to an integral value (vrndq).
static inline SimdFloat gmx_simdcall trunc(SimdFloat x)
{
    return { vrndq_f32(x.simdInternal_) };
}

//! \brief Convert float to integer, rounding to nearest with ties to even (vcvtnq).
static inline SimdFInt32 gmx_simdcall cvtR2I(SimdFloat a)
{
    return { vcvtnq_s32_f32(a.simdInternal_) };
}
//! \brief True if any lane of the float mask is set (horizontal max over unsigned lanes).
static inline bool gmx_simdcall anyTrue(SimdFBool a)
{
    return (vmaxvq_u32(a.simdInternal_) != 0);
}

//! \brief True if any lane of the integer mask is set.
static inline bool gmx_simdcall anyTrue(SimdFIBool a)
{
    return (vmaxvq_u32(a.simdInternal_) != 0);
}

//! \brief Horizontal sum of all four float lanes, returned as a scalar.
static inline float gmx_simdcall reduce(SimdFloat a)
{
    float32x4_t b = a.simdInternal_;
    // Two pairwise-add passes leave the full sum replicated in every lane.
    b = vpaddq_f32(b, b);
    b = vpaddq_f32(b, b);
    return vgetq_lane_f32(b, 0);
}
#endif // GMX_SIMD_IMPL_ARM_NEON_ASIMD_SIMD_FLOAT_H