8918e8fe9542ca3957f11113626ad8fa3f48b2d8
[alexxy/gromacs.git] / src / gromacs / simd / impl_arm_neon / impl_arm_neon_util_float.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 #ifndef GMX_SIMD_IMPL_ARM_NEON_UTIL_FLOAT_H
36 #define GMX_SIMD_IMPL_ARM_NEON_UTIL_FLOAT_H
37
38 #include "config.h"
39
40 #include <cassert>
41 #include <cstddef>
42 #include <cstdint>
43
44 #include <arm_neon.h>
45
46 #include "gromacs/utility/basedefinitions.h"
47
48 #include "impl_arm_neon_simd_float.h"
49
50
51 namespace gmx
52 {
53
54 template<int align>
55 static inline void gmx_simdcall gatherLoadTranspose(const float*       base,
56                                                     const std::int32_t offset[],
57                                                     SimdFloat*         v0,
58                                                     SimdFloat*         v1,
59                                                     SimdFloat*         v2,
60                                                     SimdFloat*         v3)
61 {
62     assert(std::size_t(offset) % 16 == 0);
63     assert(std::size_t(base) % 16 == 0);
64     assert(align % 4 == 0);
65
66     // Unfortunately we cannot use the beautiful Neon structured load
67     // instructions since the data comes from four different memory locations.
68     float32x4x2_t t0 =
69             vuzpq_f32(vld1q_f32(base + align * offset[0]), vld1q_f32(base + align * offset[2]));
70     float32x4x2_t t1 =
71             vuzpq_f32(vld1q_f32(base + align * offset[1]), vld1q_f32(base + align * offset[3]));
72     float32x4x2_t t2  = vtrnq_f32(t0.val[0], t1.val[0]);
73     float32x4x2_t t3  = vtrnq_f32(t0.val[1], t1.val[1]);
74     v0->simdInternal_ = t2.val[0];
75     v1->simdInternal_ = t3.val[0];
76     v2->simdInternal_ = t2.val[1];
77     v3->simdInternal_ = t3.val[1];
78 }
79
80 template<int align>
81 static inline void gmx_simdcall
82                    gatherLoadTranspose(const float* base, const std::int32_t offset[], SimdFloat* v0, SimdFloat* v1)
83 {
84     assert(std::size_t(offset) % 16 == 0);
85     assert(std::size_t(base) % 8 == 0);
86     assert(align % 2 == 0);
87
88     v0->simdInternal_ =
89             vcombine_f32(vld1_f32(base + align * offset[0]), vld1_f32(base + align * offset[2]));
90     v1->simdInternal_ =
91             vcombine_f32(vld1_f32(base + align * offset[1]), vld1_f32(base + align * offset[3]));
92
93     float32x4x2_t tmp = vtrnq_f32(v0->simdInternal_, v1->simdInternal_);
94
95     v0->simdInternal_ = tmp.val[0];
96     v1->simdInternal_ = tmp.val[1];
97 }
98
99 static const int c_simdBestPairAlignmentFloat = 2;
100
101 template<int align>
102 static inline void gmx_simdcall gatherLoadUTranspose(const float*       base,
103                                                      const std::int32_t offset[],
104                                                      SimdFloat*         v0,
105                                                      SimdFloat*         v1,
106                                                      SimdFloat*         v2)
107 {
108     assert(std::size_t(offset) % 16 == 0);
109
110     float32x4x2_t t0 =
111             vuzpq_f32(vld1q_f32(base + align * offset[0]), vld1q_f32(base + align * offset[2]));
112     float32x4x2_t t1 =
113             vuzpq_f32(vld1q_f32(base + align * offset[1]), vld1q_f32(base + align * offset[3]));
114     float32x4x2_t t2  = vtrnq_f32(t0.val[0], t1.val[0]);
115     float32x4x2_t t3  = vtrnq_f32(t0.val[1], t1.val[1]);
116     v0->simdInternal_ = t2.val[0];
117     v1->simdInternal_ = t3.val[0];
118     v2->simdInternal_ = t2.val[1];
119 }
120
121
122 template<int align>
123 static inline void gmx_simdcall
124                    transposeScatterStoreU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
125 {
126     assert(std::size_t(offset) % 16 == 0);
127
128     float32x4x2_t tmp = vtrnq_f32(v0.simdInternal_, v1.simdInternal_);
129
130     vst1_f32(base + align * offset[0], vget_low_f32(tmp.val[0]));
131     vst1_f32(base + align * offset[1], vget_low_f32(tmp.val[1]));
132     vst1_f32(base + align * offset[2], vget_high_f32(tmp.val[0]));
133     vst1_f32(base + align * offset[3], vget_high_f32(tmp.val[1]));
134
135     vst1q_lane_f32(base + align * offset[0] + 2, v2.simdInternal_, 0);
136     vst1q_lane_f32(base + align * offset[1] + 2, v2.simdInternal_, 1);
137     vst1q_lane_f32(base + align * offset[2] + 2, v2.simdInternal_, 2);
138     vst1q_lane_f32(base + align * offset[3] + 2, v2.simdInternal_, 3);
139 }
140
141
142 template<int align>
143 static inline void gmx_simdcall
144                    transposeScatterIncrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
145 {
146     assert(std::size_t(offset) % 16 == 0);
147
148     if (align < 4)
149     {
150         float32x2_t   t0, t1, t2, t3;
151         float32x4x2_t tmp = vtrnq_f32(v0.simdInternal_, v1.simdInternal_);
152
153         t0 = vget_low_f32(tmp.val[0]);
154         t1 = vget_low_f32(tmp.val[1]);
155         t2 = vget_high_f32(tmp.val[0]);
156         t3 = vget_high_f32(tmp.val[1]);
157
158         t0 = vadd_f32(t0, vld1_f32(base + align * offset[0]));
159         vst1_f32(base + align * offset[0], t0);
160         base[align * offset[0] + 2] += vgetq_lane_f32(v2.simdInternal_, 0);
161
162         t1 = vadd_f32(t1, vld1_f32(base + align * offset[1]));
163         vst1_f32(base + align * offset[1], t1);
164         base[align * offset[1] + 2] += vgetq_lane_f32(v2.simdInternal_, 1);
165
166         t2 = vadd_f32(t2, vld1_f32(base + align * offset[2]));
167         vst1_f32(base + align * offset[2], t2);
168         base[align * offset[2] + 2] += vgetq_lane_f32(v2.simdInternal_, 2);
169
170         t3 = vadd_f32(t3, vld1_f32(base + align * offset[3]));
171         vst1_f32(base + align * offset[3], t3);
172         base[align * offset[3] + 2] += vgetq_lane_f32(v2.simdInternal_, 3);
173     }
174     else
175     {
176         // Extra elements means we can use full width-4 load/store operations
177         float32x4x2_t t0 = vuzpq_f32(v0.simdInternal_, v2.simdInternal_);
178         float32x4x2_t t1 = vuzpq_f32(v1.simdInternal_, vdupq_n_f32(0.0F));
179         float32x4x2_t t2 = vtrnq_f32(t0.val[0], t1.val[0]);
180         float32x4x2_t t3 = vtrnq_f32(t0.val[1], t1.val[1]);
181         float32x4_t   t4 = t2.val[0];
182         float32x4_t   t5 = t3.val[0];
183         float32x4_t   t6 = t2.val[1];
184         float32x4_t   t7 = t3.val[1];
185
186         vst1q_f32(base + align * offset[0], vaddq_f32(t4, vld1q_f32(base + align * offset[0])));
187         vst1q_f32(base + align * offset[1], vaddq_f32(t5, vld1q_f32(base + align * offset[1])));
188         vst1q_f32(base + align * offset[2], vaddq_f32(t6, vld1q_f32(base + align * offset[2])));
189         vst1q_f32(base + align * offset[3], vaddq_f32(t7, vld1q_f32(base + align * offset[3])));
190     }
191 }
192
193 template<int align>
194 static inline void gmx_simdcall
195                    transposeScatterDecrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
196 {
197     assert(std::size_t(offset) % 16 == 0);
198
199     if (align < 4)
200     {
201         float32x2_t   t0, t1, t2, t3;
202         float32x4x2_t tmp = vtrnq_f32(v0.simdInternal_, v1.simdInternal_);
203
204         t0 = vget_low_f32(tmp.val[0]);
205         t1 = vget_low_f32(tmp.val[1]);
206         t2 = vget_high_f32(tmp.val[0]);
207         t3 = vget_high_f32(tmp.val[1]);
208
209         t0 = vsub_f32(vld1_f32(base + align * offset[0]), t0);
210         vst1_f32(base + align * offset[0], t0);
211         base[align * offset[0] + 2] -= vgetq_lane_f32(v2.simdInternal_, 0);
212
213         t1 = vsub_f32(vld1_f32(base + align * offset[1]), t1);
214         vst1_f32(base + align * offset[1], t1);
215         base[align * offset[1] + 2] -= vgetq_lane_f32(v2.simdInternal_, 1);
216
217         t2 = vsub_f32(vld1_f32(base + align * offset[2]), t2);
218         vst1_f32(base + align * offset[2], t2);
219         base[align * offset[2] + 2] -= vgetq_lane_f32(v2.simdInternal_, 2);
220
221         t3 = vsub_f32(vld1_f32(base + align * offset[3]), t3);
222         vst1_f32(base + align * offset[3], t3);
223         base[align * offset[3] + 2] -= vgetq_lane_f32(v2.simdInternal_, 3);
224     }
225     else
226     {
227         // Extra elements means we can use full width-4 load/store operations
228         float32x4x2_t t0 = vuzpq_f32(v0.simdInternal_, v2.simdInternal_);
229         float32x4x2_t t1 = vuzpq_f32(v1.simdInternal_, vdupq_n_f32(0.0F));
230         float32x4x2_t t2 = vtrnq_f32(t0.val[0], t1.val[0]);
231         float32x4x2_t t3 = vtrnq_f32(t0.val[1], t1.val[1]);
232         float32x4_t   t4 = t2.val[0];
233         float32x4_t   t5 = t3.val[0];
234         float32x4_t   t6 = t2.val[1];
235         float32x4_t   t7 = t3.val[1];
236
237         vst1q_f32(base + align * offset[0], vsubq_f32(vld1q_f32(base + align * offset[0]), t4));
238         vst1q_f32(base + align * offset[1], vsubq_f32(vld1q_f32(base + align * offset[1]), t5));
239         vst1q_f32(base + align * offset[2], vsubq_f32(vld1q_f32(base + align * offset[2]), t6));
240         vst1q_f32(base + align * offset[3], vsubq_f32(vld1q_f32(base + align * offset[3]), t7));
241     }
242 }
243
244 static inline void gmx_simdcall expandScalarsToTriplets(SimdFloat  scalar,
245                                                         SimdFloat* triplets0,
246                                                         SimdFloat* triplets1,
247                                                         SimdFloat* triplets2)
248 {
249     float32x2_t lo, hi;
250     float32x4_t t0, t1, t2, t3;
251
252     lo = vget_low_f32(scalar.simdInternal_);
253     hi = vget_high_f32(scalar.simdInternal_);
254
255     t0 = vdupq_lane_f32(lo, 0);
256     t1 = vdupq_lane_f32(lo, 1);
257     t2 = vdupq_lane_f32(hi, 0);
258     t3 = vdupq_lane_f32(hi, 1);
259
260     triplets0->simdInternal_ = vextq_f32(t0, t1, 1);
261     triplets1->simdInternal_ = vextq_f32(t1, t2, 2);
262     triplets2->simdInternal_ = vextq_f32(t2, t3, 3);
263 }
264
265
266 template<int align>
267 static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const float* base,
268                                                              SimdFInt32   offset,
269                                                              SimdFloat*   v0,
270                                                              SimdFloat*   v1,
271                                                              SimdFloat*   v2,
272                                                              SimdFloat*   v3)
273 {
274     alignas(GMX_SIMD_ALIGNMENT) std::int32_t ioffset[GMX_SIMD_FINT32_WIDTH];
275
276     assert(std::size_t(base) % 16 == 0);
277     assert(align % 4 == 0);
278
279     store(ioffset, offset);
280     gatherLoadTranspose<align>(base, ioffset, v0, v1, v2, v3);
281 }
282
283 template<int align>
284 static inline void gmx_simdcall
285                    gatherLoadBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v0, SimdFloat* v1)
286 {
287     alignas(GMX_SIMD_ALIGNMENT) std::int32_t ioffset[GMX_SIMD_FINT32_WIDTH];
288
289     store(ioffset, offset);
290     gatherLoadTranspose<align>(base, ioffset, v0, v1);
291 }
292
293
294 template<int align>
295 static inline void gmx_simdcall
296                    gatherLoadUBySimdIntTranspose(const float* base, SimdFInt32 offset, SimdFloat* v0, SimdFloat* v1)
297 {
298     alignas(GMX_SIMD_ALIGNMENT) std::int32_t ioffset[GMX_SIMD_FINT32_WIDTH];
299
300     store(ioffset, offset);
301     v0->simdInternal_ =
302             vcombine_f32(vld1_f32(base + align * ioffset[0]), vld1_f32(base + align * ioffset[2]));
303     v1->simdInternal_ =
304             vcombine_f32(vld1_f32(base + align * ioffset[1]), vld1_f32(base + align * ioffset[3]));
305     float32x4x2_t tmp = vtrnq_f32(v0->simdInternal_, v1->simdInternal_);
306     v0->simdInternal_ = tmp.val[0];
307     v1->simdInternal_ = tmp.val[1];
308 }
309
310 static inline float gmx_simdcall reduceIncr4ReturnSum(float* m, SimdFloat v0, SimdFloat v1, SimdFloat v2, SimdFloat v3)
311 {
312     assert(std::size_t(m) % 16 == 0);
313
314     float32x4x2_t t0 = vuzpq_f32(v0.simdInternal_, v2.simdInternal_);
315     float32x4x2_t t1 = vuzpq_f32(v1.simdInternal_, v3.simdInternal_);
316     float32x4x2_t t2 = vtrnq_f32(t0.val[0], t1.val[0]);
317     float32x4x2_t t3 = vtrnq_f32(t0.val[1], t1.val[1]);
318     v0.simdInternal_ = t2.val[0];
319     v1.simdInternal_ = t3.val[0];
320     v2.simdInternal_ = t2.val[1];
321     v3.simdInternal_ = t3.val[1];
322
323     v0 = v0 + v1;
324     v2 = v2 + v3;
325     v0 = v0 + v2;
326     v2 = v0 + simdLoad(m);
327     store(m, v2);
328
329     return reduce(v0);
330 }
331
332 } // namespace gmx
333
334 #endif // GMX_SIMD_IMPL_ARM_NEON_UTIL_FLOAT_H