Apply clang-format-11
[alexxy/gromacs.git] / src / gromacs / simd / impl_arm_neon_asimd / impl_arm_neon_asimd_util_float.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2014,2015,2016,2017,2018 by the GROMACS development team.
5  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36 #ifndef GMX_SIMD_IMPL_ARM_NEON_ASIMD_UTIL_FLOAT_H
37 #define GMX_SIMD_IMPL_ARM_NEON_ASIMD_UTIL_FLOAT_H
38
39 #include "config.h"
40
41 #include <cassert>
42 #include <cstddef>
43 #include <cstdint>
44
45 #include <arm_neon.h>
46
47 #include "gromacs/utility/basedefinitions.h"
48
49
50 namespace gmx
51 {
52
53 template<int align>
54 static inline void gmx_simdcall gatherLoadTranspose(const float*       base,
55                                                     const std::int32_t offset[],
56                                                     SimdFloat*         v0,
57                                                     SimdFloat*         v1,
58                                                     SimdFloat*         v2,
59                                                     SimdFloat*         v3)
60 {
61     assert(std::size_t(offset) % 16 == 0);
62     assert(std::size_t(base) % 16 == 0);
63     assert(align % 4 == 0);
64
65     // Unfortunately we cannot use the beautiful Neon structured load
66     // instructions since the data comes from four different memory locations.
67     float32x4x2_t t0 =
68             vuzpq_f32(vld1q_f32(base + align * offset[0]), vld1q_f32(base + align * offset[2]));
69     float32x4x2_t t1 =
70             vuzpq_f32(vld1q_f32(base + align * offset[1]), vld1q_f32(base + align * offset[3]));
71     float32x4x2_t t2  = vtrnq_f32(t0.val[0], t1.val[0]);
72     float32x4x2_t t3  = vtrnq_f32(t0.val[1], t1.val[1]);
73     v0->simdInternal_ = t2.val[0];
74     v1->simdInternal_ = t3.val[0];
75     v2->simdInternal_ = t2.val[1];
76     v3->simdInternal_ = t3.val[1];
77 }
78
79 template<int align>
80 static inline void gmx_simdcall
81 gatherLoadTranspose(const float* base, const std::int32_t offset[], SimdFloat* v0, SimdFloat* v1)
82 {
83     assert(std::size_t(offset) % 16 == 0);
84     assert(std::size_t(base) % 8 == 0);
85     assert(align % 2 == 0);
86
87     v0->simdInternal_ =
88             vcombine_f32(vld1_f32(base + align * offset[0]), vld1_f32(base + align * offset[2]));
89     v1->simdInternal_ =
90             vcombine_f32(vld1_f32(base + align * offset[1]), vld1_f32(base + align * offset[3]));
91
92     float32x4x2_t tmp = vtrnq_f32(v0->simdInternal_, v1->simdInternal_);
93
94     v0->simdInternal_ = tmp.val[0];
95     v1->simdInternal_ = tmp.val[1];
96 }
97
98 static const int c_simdBestPairAlignmentFloat = 2;
99
100 template<int align>
101 static inline void gmx_simdcall gatherLoadUTranspose(const float*       base,
102                                                      const std::int32_t offset[],
103                                                      SimdFloat*         v0,
104                                                      SimdFloat*         v1,
105                                                      SimdFloat*         v2)
106 {
107     assert(std::size_t(offset) % 16 == 0);
108
109     float32x4x2_t t0 =
110             vuzpq_f32(vld1q_f32(base + align * offset[0]), vld1q_f32(base + align * offset[2]));
111     float32x4x2_t t1 =
112             vuzpq_f32(vld1q_f32(base + align * offset[1]), vld1q_f32(base + align * offset[3]));
113     float32x4x2_t t2  = vtrnq_f32(t0.val[0], t1.val[0]);
114     float32x4x2_t t3  = vtrnq_f32(t0.val[1], t1.val[1]);
115     v0->simdInternal_ = t2.val[0];
116     v1->simdInternal_ = t3.val[0];
117     v2->simdInternal_ = t2.val[1];
118 }
119
120
121 template<int align>
122 static inline void gmx_simdcall
123 transposeScatterStoreU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
124 {
125     assert(std::size_t(offset) % 16 == 0);
126
127     float32x4x2_t tmp = vtrnq_f32(v0.simdInternal_, v1.simdInternal_);
128
129     vst1_f32(base + align * offset[0], vget_low_f32(tmp.val[0]));
130     vst1_f32(base + align * offset[1], vget_low_f32(tmp.val[1]));
131     vst1_f32(base + align * offset[2], vget_high_f32(tmp.val[0]));
132     vst1_f32(base + align * offset[3], vget_high_f32(tmp.val[1]));
133
134     vst1q_lane_f32(base + align * offset[0] + 2, v2.simdInternal_, 0);
135     vst1q_lane_f32(base + align * offset[1] + 2, v2.simdInternal_, 1);
136     vst1q_lane_f32(base + align * offset[2] + 2, v2.simdInternal_, 2);
137     vst1q_lane_f32(base + align * offset[3] + 2, v2.simdInternal_, 3);
138 }
139
140
141 template<int align>
142 static inline void gmx_simdcall
143 transposeScatterIncrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
144 {
145     assert(std::size_t(offset) % 16 == 0);
146
147     if (align < 4)
148     {
149         float32x2_t   t0, t1, t2, t3;
150         float32x4x2_t tmp = vtrnq_f32(v0.simdInternal_, v1.simdInternal_);
151
152         t0 = vget_low_f32(tmp.val[0]);
153         t1 = vget_low_f32(tmp.val[1]);
154         t2 = vget_high_f32(tmp.val[0]);
155         t3 = vget_high_f32(tmp.val[1]);
156
157         t0 = vadd_f32(t0, vld1_f32(base + align * offset[0]));
158         vst1_f32(base + align * offset[0], t0);
159         base[align * offset[0] + 2] += vgetq_lane_f32(v2.simdInternal_, 0);
160
161         t1 = vadd_f32(t1, vld1_f32(base + align * offset[1]));
162         vst1_f32(base + align * offset[1], t1);
163         base[align * offset[1] + 2] += vgetq_lane_f32(v2.simdInternal_, 1);
164
165         t2 = vadd_f32(t2, vld1_f32(base + align * offset[2]));
166         vst1_f32(base + align * offset[2], t2);
167         base[align * offset[2] + 2] += vgetq_lane_f32(v2.simdInternal_, 2);
168
169         t3 = vadd_f32(t3, vld1_f32(base + align * offset[3]));
170         vst1_f32(base + align * offset[3], t3);
171         base[align * offset[3] + 2] += vgetq_lane_f32(v2.simdInternal_, 3);
172     }
173     else
174     {
175         // Extra elements means we can use full width-4 load/store operations
176         float32x4x2_t t0 = vuzpq_f32(v0.simdInternal_, v2.simdInternal_);
177         float32x4x2_t t1 = vuzpq_f32(v1.simdInternal_, vdupq_n_f32(0.0F));
178         float32x4x2_t t2 = vtrnq_f32(t0.val[0], t1.val[0]);
179         float32x4x2_t t3 = vtrnq_f32(t0.val[1], t1.val[1]);
180         float32x4_t   t4 = t2.val[0];
181         float32x4_t   t5 = t3.val[0];
182         float32x4_t   t6 = t2.val[1];
183         float32x4_t   t7 = t3.val[1];
184
185         vst1q_f32(base + align * offset[0], vaddq_f32(t4, vld1q_f32(base + align * offset[0])));
186         vst1q_f32(base + align * offset[1], vaddq_f32(t5, vld1q_f32(base + align * offset[1])));
187         vst1q_f32(base + align * offset[2], vaddq_f32(t6, vld1q_f32(base + align * offset[2])));
188         vst1q_f32(base + align * offset[3], vaddq_f32(t7, vld1q_f32(base + align * offset[3])));
189     }
190 }
191
192 template<int align>
193 static inline void gmx_simdcall
194 transposeScatterDecrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
195 {
196     assert(std::size_t(offset) % 16 == 0);
197
198     if (align < 4)
199     {
200         float32x2_t   t0, t1, t2, t3;
201         float32x4x2_t tmp = vtrnq_f32(v0.simdInternal_, v1.simdInternal_);
202
203         t0 = vget_low_f32(tmp.val[0]);
204         t1 = vget_low_f32(tmp.val[1]);
205         t2 = vget_high_f32(tmp.val[0]);
206         t3 = vget_high_f32(tmp.val[1]);
207
208         t0 = vsub_f32(vld1_f32(base + align * offset[0]), t0);
209         vst1_f32(base + align * offset[0], t0);
210         base[align * offset[0] + 2] -= vgetq_lane_f32(v2.simdInternal_, 0);
211
212         t1 = vsub_f32(vld1_f32(base + align * offset[1]), t1);
213         vst1_f32(base + align * offset[1], t1);
214         base[align * offset[1] + 2] -= vgetq_lane_f32(v2.simdInternal_, 1);
215
216         t2 = vsub_f32(vld1_f32(base + align * offset[2]), t2);
217         vst1_f32(base + align * offset[2], t2);
218         base[align * offset[2] + 2] -= vgetq_lane_f32(v2.simdInternal_, 2);
219
220         t3 = vsub_f32(vld1_f32(base + align * offset[3]), t3);
221         vst1_f32(base + align * offset[3], t3);
222         base[align * offset[3] + 2] -= vgetq_lane_f32(v2.simdInternal_, 3);
223     }
224     else
225     {
226         // Extra elements means we can use full width-4 load/store operations
227         float32x4x2_t t0 = vuzpq_f32(v0.simdInternal_, v2.simdInternal_);
228         float32x4x2_t t1 = vuzpq_f32(v1.simdInternal_, vdupq_n_f32(0.0F));
229         float32x4x2_t t2 = vtrnq_f32(t0.val[0], t1.val[0]);
230         float32x4x2_t t3 = vtrnq_f32(t0.val[1], t1.val[1]);
231         float32x4_t   t4 = t2.val[0];
232         float32x4_t   t5 = t3.val[0];
233         float32x4_t   t6 = t2.val[1];
234         float32x4_t   t7 = t3.val[1];
235
236         vst1q_f32(base + align * offset[0], vsubq_f32(vld1q_f32(base + align * offset[0]), t4));
237         vst1q_f32(base + align * offset[1], vsubq_f32(vld1q_f32(base + align * offset[1]), t5));
238         vst1q_f32(base + align * offset[2], vsubq_f32(vld1q_f32(base + align * offset[2]), t6));
239         vst1q_f32(base + align * offset[3], vsubq_f32(vld1q_f32(base + align * offset[3]), t7));
240     }
241 }
242
243 static inline void gmx_simdcall expandScalarsToTriplets(SimdFloat  scalar,
244                                                         SimdFloat* triplets0,
245                                                         SimdFloat* triplets1,
246                                                         SimdFloat* triplets2)
247 {
248     float32x2_t lo, hi;
249     float32x4_t t0, t1, t2, t3;
250
251     lo = vget_low_f32(scalar.simdInternal_);
252     hi = vget_high_f32(scalar.simdInternal_);
253
254     t0 = vdupq_lane_f32(lo, 0);
255     t1 = vdupq_lane_f32(lo, 1);
256     t2 = vdupq_lane_f32(hi, 0);
257     t3 = vdupq_lane_f32(hi, 1);
258
259     triplets0->simdInternal_ = vextq_f32(t0, t1, 1);
260     triplets1->simdInternal_ = vextq_f32(t1, t2, 2);
261     triplets2->simdInternal_ = vextq_f32(t2, t3, 3);
262 }
263
264
265 template<int align>
266 static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const float* base,
267                                                              SimdFInt32   offset,
268                                                              SimdFloat*   v0,
269                                                              SimdFloat*   v1,
270                                                              SimdFloat*   v2,
271                                                              SimdFloat*   v3)
272 {
273     alignas(GMX_SIMD_ALIGNMENT) std::int32_t ioffset[GMX_SIMD_FINT32_WIDTH];
274
275     assert(std::size_t(base) % 16 == 0);
276     assert(align % 4 == 0);
277
278     store(ioffset, offset);
279     gatherLoadTranspose<align>(base, ioffset, v0, v1, v2, v3);
280 }
281
282 template<int align>
283 static inline void gmx_simdcall
284 gatherLoadBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v0, SimdFloat* v1)
285 {
286     alignas(GMX_SIMD_ALIGNMENT) std::int32_t ioffset[GMX_SIMD_FINT32_WIDTH];
287
288     store(ioffset, offset);
289     gatherLoadTranspose<align>(base, ioffset, v0, v1);
290 }
291
292
293 template<int align>
294 static inline void gmx_simdcall
295 gatherLoadUBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v0, SimdFloat* v1)
296 {
297     alignas(GMX_SIMD_ALIGNMENT) std::int32_t ioffset[GMX_SIMD_FINT32_WIDTH];
298
299     store(ioffset, offset);
300     v0->simdInternal_ =
301             vcombine_f32(vld1_f32(base + align * ioffset[0]), vld1_f32(base + align * ioffset[2]));
302     v1->simdInternal_ =
303             vcombine_f32(vld1_f32(base + align * ioffset[1]), vld1_f32(base + align * ioffset[3]));
304     float32x4x2_t tmp = vtrnq_f32(v0->simdInternal_, v1->simdInternal_);
305     v0->simdInternal_ = tmp.val[0];
306     v1->simdInternal_ = tmp.val[1];
307 }
308
309 static inline float gmx_simdcall reduceIncr4ReturnSum(float* m, SimdFloat v0, SimdFloat v1, SimdFloat v2, SimdFloat v3)
310 {
311     assert(std::size_t(m) % 16 == 0);
312
313     float32x4x2_t t0 = vuzpq_f32(v0.simdInternal_, v2.simdInternal_);
314     float32x4x2_t t1 = vuzpq_f32(v1.simdInternal_, v3.simdInternal_);
315     float32x4x2_t t2 = vtrnq_f32(t0.val[0], t1.val[0]);
316     float32x4x2_t t3 = vtrnq_f32(t0.val[1], t1.val[1]);
317     v0.simdInternal_ = t2.val[0];
318     v1.simdInternal_ = t3.val[0];
319     v2.simdInternal_ = t2.val[1];
320     v3.simdInternal_ = t3.val[1];
321
322     v0 = v0 + v1;
323     v2 = v2 + v3;
324     v0 = v0 + v2;
325     v2 = v0 + simdLoad(m);
326     store(m, v2);
327
328     return reduce(v0);
329 }
330
331 } // namespace gmx
332
333 #endif // GMX_SIMD_IMPL_ARM_NEON_ASIMD_UTIL_FLOAT_H