Apply clang-format-11
[alexxy/gromacs.git] / src / gromacs / simd / impl_arm_sve / impl_arm_sve_util_float.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2020 Research Organization for Information Science and Technology (RIST).
5  * Copyright (c) 2020,2021, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36
37 /*
38  * armv8+sve support to GROMACS was contributed by the Research Organization for
39  * Information Science and Technology (RIST).
40  */
41
42 #ifndef GMX_SIMD_IMPL_ARM_SVE_UTIL_FLOAT_H
43 #define GMX_SIMD_IMPL_ARM_SVE_UTIL_FLOAT_H
44
45 #include "config.h"
46
47 #include <cassert>
48 #include <cstddef>
49 #include <cstdint>
50
51 #include <arm_sve.h>
52
53 #include "gromacs/utility/basedefinitions.h"
54
55 #include "impl_arm_sve_simd_float.h"
56
57 #define SVE_FLOAT_HALF_MASK svwhilelt_b32(0, GMX_SIMD_FLOAT_WIDTH / 2)
58 #define SVE_FINT32_HALF_MASK svwhilelt_b32(0, GMX_SIMD_FLOAT_WIDTH / 2)
59
60 #define SVE_FLOAT4_MASK svptrue_pat_b32(SV_VL4)
61 #define SVE_FLOAT3_MASK svptrue_pat_b32(SV_VL3)
62
63 namespace gmx
64 {
65
66 template<int align>
67 static inline void gmx_simdcall
68 gatherLoadBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v0, SimdFloat* v1)
69 {
70     // Base pointer must be aligned to the smaller of 2 elements and float SIMD width
71     assert(std::size_t(base) % 8 == 0);
72     // align parameter must also be a multiple of the above alignment requirement
73     assert(align % 2 == 0);
74
75     if (align < 2)
76     {
77         svbool_t  pg = svptrue_b32();
78         svint32_t offsets;
79         offsets           = svmul_n_s32_x(pg, offset.simdInternal_, align * 4);
80         v0->simdInternal_ = svld1_gather_s32offset_f32(pg, base, offsets);
81         offsets           = svadd_n_s32_x(pg, offsets, 4);
82         v1->simdInternal_ = svld1_gather_s32offset_f32(pg, base, offsets);
83     }
84     else if (2 == align)
85     {
86         svbool_t    pg    = svptrue_b32();
87         svfloat32_t t0    = svreinterpret_f32_u64(svld1_gather_s64index_u64(
88                 svunpklo_b(pg), (uint64_t*)base, svunpklo_s64(offset.simdInternal_)));
89         svfloat32_t t1    = svreinterpret_f32_u64(svld1_gather_s64index_u64(
90                 svunpkhi_b(pg), (uint64_t*)base, svunpkhi_s64(offset.simdInternal_)));
91         v0->simdInternal_ = svuzp1(t0, t1);
92         v1->simdInternal_ = svuzp2(t0, t1);
93     }
94     else
95     {
96         svbool_t    pg      = svptrue_b32();
97         svint32_t   offsets = svmul_n_s32_x(pg, offset.simdInternal_, align / 2);
98         svfloat32_t t0      = svreinterpret_f32_u64(
99                 svld1_gather_s64index_u64(svunpklo_b(pg), (uint64_t*)base, svunpklo_s64(offsets)));
100         svfloat32_t t1 = svreinterpret_f32_u64(
101                 svld1_gather_s64index_u64(svunpkhi_b(pg), (uint64_t*)base, svunpkhi_s64(offsets)));
102         v0->simdInternal_ = svuzp1(t0, t1);
103         v1->simdInternal_ = svuzp2(t0, t1);
104     }
105 }
106
107 template<int align>
108 static inline void gmx_simdcall gatherLoadTranspose(const float*       base,
109                                                     const std::int32_t offset[],
110                                                     SimdFloat*         v0,
111                                                     SimdFloat*         v1,
112                                                     SimdFloat*         v2,
113                                                     SimdFloat*         v3)
114 {
115     assert(std::size_t(offset) % 16 == 0);
116     assert(std::size_t(base) % 16 == 0);
117     assert(align % 4 == 0);
118
119     svint32_t offsets;
120     offsets = svld1_s32(svptrue_b32(), offset);
121     gatherLoadBySimdIntTranspose<align>(base, offsets, v0, v1, v2, v3);
122 }
123
124 template<int align>
125 static inline void gmx_simdcall
126 gatherLoadTranspose(const float* base, const std::int32_t offset[], SimdFloat* v0, SimdFloat* v1)
127 {
128     assert(std::size_t(offset) % 64 == 0);
129     assert(std::size_t(base) % 8 == 0);
130     assert(align % 2 == 0);
131
132     SimdFInt32 offsets;
133     svbool_t   pg         = svptrue_b32();
134     offsets.simdInternal_ = svld1(pg, offset);
135     gatherLoadBySimdIntTranspose<align>(base, offsets, v0, v1);
136 }
137
138 static const int c_simdBestPairAlignmentFloat = 2;
139
140 template<int align>
141 static inline void gmx_simdcall gatherLoadUTranspose(const float*       base,
142                                                      const std::int32_t offset[],
143                                                      SimdFloat*         v0,
144                                                      SimdFloat*         v1,
145                                                      SimdFloat*         v2)
146 {
147     assert(std::size_t(offset) % 16 == 0);
148
149     svint32_t offsets;
150     svbool_t  pg      = svptrue_b32();
151     offsets           = svmul_n_s32_x(pg, svld1_s32(pg, offset), align * 4);
152     v0->simdInternal_ = svld1_gather_s32offset_f32(pg, base, offsets);
153     offsets           = svadd_n_s32_x(pg, offsets, 4);
154     v1->simdInternal_ = svld1_gather_s32offset_f32(pg, base, offsets);
155     offsets           = svadd_n_s32_x(pg, offsets, 4);
156     v2->simdInternal_ = svld1_gather_s32offset_f32(pg, base, offsets);
157 }
158
159
160 template<int align>
161 static inline void gmx_simdcall
162 transposeScatterStoreU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
163 {
164     assert(std::size_t(offset) % 16 == 0);
165
166     svint32_t offsets;
167     svbool_t  pg = svptrue_b32();
168     offsets      = svmul_n_s32_x(pg, svld1_s32(pg, offset), align * 4);
169     svst1_scatter_s32offset_f32(pg, base, offsets, v0.simdInternal_);
170     offsets = svadd_n_s32_x(pg, offsets, 4);
171     svst1_scatter_s32offset_f32(pg, base, offsets, v1.simdInternal_);
172     offsets = svadd_n_s32_x(pg, offsets, 4);
173     svst1_scatter_s32offset_f32(pg, base, offsets, v2.simdInternal_);
174 }
175
176
177 template<int align>
178 static inline void gmx_simdcall
179 transposeScatterIncrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
180 {
181     assert(std::size_t(offset) % 64 == 0);
182
183     svbool_t                          pg = svptrue_b32();
184     svfloat32x3_t                     v;
185     alignas(GMX_SIMD_ALIGNMENT) float tvec[3 * GMX_SIMD_FLOAT_WIDTH];
186     v = svcreate3_f32(v0.simdInternal_, v1.simdInternal_, v2.simdInternal_);
187     svst3_f32(pg, tvec, v);
188     pg = SVE_FLOAT3_MASK;
189     for (int i = 0; i < GMX_SIMD_FLOAT_WIDTH; i++)
190     {
191         svfloat32_t t1 = svld1_f32(pg, base + align * offset[i]);
192         svfloat32_t t2 = svld1_f32(pg, tvec + 3 * i);
193         svfloat32_t t3 = svadd_f32_x(pg, t1, t2);
194         svst1_f32(pg, base + align * offset[i], t3);
195     }
196 }
197
198 template<int align>
199 static inline void gmx_simdcall
200 transposeScatterDecrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
201 {
202     assert(std::size_t(offset) % 16 == 0);
203
204     svbool_t                          pg = svptrue_b32();
205     svfloat32x3_t                     v;
206     alignas(GMX_SIMD_ALIGNMENT) float tvec[3 * GMX_SIMD_FLOAT_WIDTH];
207     v = svcreate3_f32(v0.simdInternal_, v1.simdInternal_, v2.simdInternal_);
208     svst3_f32(pg, tvec, v);
209     pg = SVE_FLOAT3_MASK;
210     for (int i = 0; i < GMX_SIMD_FLOAT_WIDTH; i++)
211     {
212         svfloat32_t t1 = svld1_f32(pg, base + align * offset[i]);
213         svfloat32_t t2 = svld1_f32(pg, tvec + 3 * i);
214         svfloat32_t t3 = svsub_f32_x(pg, t1, t2);
215         svst1_f32(pg, base + align * offset[i], t3);
216     }
217 }
218
219 static inline void gmx_simdcall expandScalarsToTriplets(SimdFloat  scalar,
220                                                         SimdFloat* triplets0,
221                                                         SimdFloat* triplets1,
222                                                         SimdFloat* triplets2)
223 {
224     assert(GMX_SIMD_FLOAT_WIDTH <= 16);
225     uint32_t   ind[48] = { 0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  5,
226                          5,  5,  6,  6,  6,  7,  7,  7,  8,  8,  8,  9,  9,  9,  10, 10,
227                          10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15 };
228     svbool_t   pg;
229     svuint32_t idx;
230
231     pg                       = svptrue_b32();
232     idx                      = svld1_u32(pg, ind);
233     triplets0->simdInternal_ = svtbl_f32(scalar.simdInternal_, idx);
234     idx                      = svld1_u32(pg, ind + GMX_SIMD_FLOAT_WIDTH);
235     triplets1->simdInternal_ = svtbl_f32(scalar.simdInternal_, idx);
236     idx                      = svld1_u32(pg, ind + 2 * GMX_SIMD_FLOAT_WIDTH);
237     triplets2->simdInternal_ = svtbl_f32(scalar.simdInternal_, idx);
238 }
239
240 template<int align>
241 static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const float* base,
242                                                              SimdFInt32   offset,
243                                                              SimdFloat*   v0,
244                                                              SimdFloat*   v1,
245                                                              SimdFloat*   v2,
246                                                              SimdFloat*   v3)
247 {
248     assert(std::size_t(base) % 16 == 0);
249     assert(align % 4 == 0);
250
251     svbool_t pg          = svptrue_b32();
252     offset.simdInternal_ = svmul_n_s32_x(pg, offset.simdInternal_, align);
253     v0->simdInternal_    = svld1_gather_s32index_f32(pg, base, offset.simdInternal_);
254     offset.simdInternal_ = svadd_n_s32_x(pg, offset.simdInternal_, 1);
255     v1->simdInternal_    = svld1_gather_s32index_f32(pg, base, offset.simdInternal_);
256     offset.simdInternal_ = svadd_n_s32_x(pg, offset.simdInternal_, 1);
257     v2->simdInternal_    = svld1_gather_s32index_f32(pg, base, offset.simdInternal_);
258     offset.simdInternal_ = svadd_n_s32_x(pg, offset.simdInternal_, 1);
259     v3->simdInternal_    = svld1_gather_s32index_f32(pg, base, offset.simdInternal_);
260 }
261
262
263 template<int align>
264 static inline void gmx_simdcall
265 gatherLoadUBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v0, SimdFloat* v1)
266 {
267     svbool_t  pg      = svptrue_b32();
268     svint32_t offsets = svmul_n_s32_x(pg, offset.simdInternal_, align * 4);
269     v0->simdInternal_ = svld1_gather_s32offset_f32(pg, base, offsets);
270     offsets           = svadd_n_s32_x(pg, offsets, 4);
271     v1->simdInternal_ = svld1_gather_s32offset_f32(pg, base, offsets);
272 }
273
274 static inline float gmx_simdcall reduceIncr4ReturnSum(float* m, SimdFloat v0, SimdFloat v1, SimdFloat v2, SimdFloat v3)
275 {
276     assert(std::size_t(m) % 16 == 0);
277     svbool_t    pg = svptrue_b32();
278     svfloat32_t _m, _s;
279     float32_t   sum[4];
280     sum[0] = svadda_f32(pg, 0.0f, v0.simdInternal_);
281     sum[1] = svadda_f32(pg, 0.0f, v1.simdInternal_);
282     sum[2] = svadda_f32(pg, 0.0f, v2.simdInternal_);
283     sum[3] = svadda_f32(pg, 0.0f, v3.simdInternal_);
284     pg     = SVE_FLOAT4_MASK;
285     _m     = svld1_f32(pg, m);
286     _s     = svld1_f32(pg, sum);
287     svst1_f32(pg, m, svadd_f32_x(pg, _m, _s));
288     return svadda_f32(pg, 0.0f, _s);
289 }
290
291 static inline SimdFloat gmx_simdcall loadDualHsimd(const float* m0, const float* m1)
292 {
293     svfloat32_t v0, v1;
294     svbool_t    pg = SVE_FLOAT_HALF_MASK;
295     v0             = svld1_f32(pg, m0);
296     v1             = svld1_f32(pg, m1);
297     return { svsplice_f32(pg, v0, v1) };
298 }
299
300 static inline SimdFloat gmx_simdcall loadDuplicateHsimd(const float* m)
301 {
302     svfloat32_t v;
303     svbool_t    pg = SVE_FLOAT_HALF_MASK;
304     v              = svld1_f32(pg, m);
305     return { svsplice_f32(pg, v, v) };
306 }
307
308 static inline SimdFloat gmx_simdcall loadU1DualHsimd(const float* m)
309 {
310     svfloat32_t v0, v1;
311     svbool_t    pg = SVE_FLOAT_HALF_MASK;
312     v0             = svdup_f32(m[0]);
313     v1             = svdup_f32(m[1]);
314     return { svsplice_f32(pg, v0, v1) };
315 }
316
317 static inline void gmx_simdcall storeDualHsimd(float* m0, float* m1, SimdFloat a)
318 {
319     svbool_t pg = SVE_FLOAT_HALF_MASK;
320     svst1_f32(pg, m0, a.simdInternal_);
321     svst1_f32(pg, m1, svext_f32(a.simdInternal_, a.simdInternal_, GMX_SIMD_FLOAT_WIDTH / 2));
322 }
323
324 static inline void gmx_simdcall incrDualHsimd(float* m0, float* m1, SimdFloat a)
325 {
326     // Make sure the memory pointer is aligned to half float SIMD width
327     assert(std::size_t(m0) % (GMX_SIMD_FLOAT_WIDTH * sizeof(float) / 2) == 0);
328     assert(std::size_t(m1) % (GMX_SIMD_FLOAT_WIDTH * sizeof(float) / 2) == 0);
329
330     svbool_t    pg = SVE_FLOAT_HALF_MASK;
331     svfloat32_t v0, v2, v3;
332     v0 = svld1_f32(pg, m0);
333     v2 = svadd_f32_x(pg, v0, a.simdInternal_);
334     svst1_f32(pg, m0, v2);
335     v0 = svld1_f32(pg, m1);
336     v3 = svext_f32(a.simdInternal_, a.simdInternal_, GMX_SIMD_FLOAT_WIDTH / 2);
337     v2 = svadd_f32_x(pg, v0, v3);
338     svst1_f32(pg, m1, v2);
339 }
340
341 static inline void gmx_simdcall decr3Hsimd(float* m, SimdFloat a0, SimdFloat a1, SimdFloat a2)
342 {
343     svbool_t    pg  = svptrue_b32();
344     svbool_t    pg2 = SVE_FLOAT_HALF_MASK;
345     svfloat32_t v0, v1, v2, v3;
346     v0 = svld1_f32(pg, m);
347     v1 = svext_f32(a0.simdInternal_, a1.simdInternal_, GMX_SIMD_FLOAT_WIDTH / 2);
348     v2 = svsel_f32(pg2, a0.simdInternal_, a1.simdInternal_);
349     v1 = svadd_f32_x(pg, v1, v2);
350     v0 = svsub_f32_z(pg, v0, v1);
351     svst1_f32(pg, m, v0);
352     v0 = svld1_f32(pg2, m + GMX_SIMD_FLOAT_WIDTH);
353     v1 = svext_f32(a2.simdInternal_, a0.simdInternal_, GMX_SIMD_FLOAT_WIDTH / 2);
354     v2 = svadd_f32_x(pg2, a2.simdInternal_, v1);
355     v3 = svsub_f32_x(pg2, v0, v2);
356     svst1_f32(pg2, m + GMX_SIMD_FLOAT_WIDTH, v3);
357 }
358
359 static inline float gmx_simdcall reduceIncr4ReturnSumHsimd(float* m, SimdFloat v0, SimdFloat v1)
360 {
361     svbool_t    pg  = SVE_FLOAT_HALF_MASK;
362     svbool_t    pg2 = sveor_b_z(svptrue_b32(), pg, svptrue_b32());
363     svfloat32_t _m, _s;
364
365     _s = svdup_f32(0.0f);
366     _s = svinsr_n_f32(_s, svaddv_f32(pg2, v1.simdInternal_));
367     _s = svinsr_n_f32(_s, svaddv_f32(pg, v1.simdInternal_));
368     _s = svinsr_n_f32(_s, svaddv_f32(pg2, v0.simdInternal_));
369     _s = svinsr_n_f32(_s, svaddv_f32(pg, v0.simdInternal_));
370
371     pg = SVE_FLOAT4_MASK;
372     _m = svld1_f32(pg, m);
373     svst1_f32(pg, m, svadd_f32_x(pg, _m, _s));
374     return svaddv_f32(pg, _s);
375 }
376
377 template<int align>
378 static inline void gmx_simdcall gatherLoadTransposeHsimd(const float*       base0,
379                                                          const float*       base1,
380                                                          const std::int32_t offset[],
381                                                          SimdFloat*         v0,
382                                                          SimdFloat*         v1)
383 {
384     svint64_t   offsets = svunpklo_s64(svld1_s32(svptrue_b32(), offset));
385     svfloat32_t _v0, _v1;
386     if (2 == align)
387     {
388         _v0 = svreinterpret_f32_f64(svld1_gather_s64index_f64(SVE_DOUBLE_MASK, (double*)base0, offsets));
389         _v1 = svreinterpret_f32_f64(svld1_gather_s64index_f64(SVE_DOUBLE_MASK, (double*)base1, offsets));
390     }
391     else
392     {
393         offsets = svmul_n_s64_x(svptrue_b64(), offsets, align * 4);
394         _v0 = svreinterpret_f32_f64(svld1_gather_s64offset_f64(SVE_DOUBLE_MASK, (double*)base0, offsets));
395         _v1 = svreinterpret_f32_f64(svld1_gather_s64offset_f64(SVE_DOUBLE_MASK, (double*)base1, offsets));
396     }
397     v0->simdInternal_ = svuzp1(_v0, _v1);
398     v1->simdInternal_ = svuzp2(_v0, _v1);
399 }
400
401 } // namespace gmx
402
403 #endif // GMX_SIMD_IMPL_ARM_SVE_UTIL_FLOAT_H