Apply clang-format-11
[alexxy/gromacs.git] / src / gromacs / simd / impl_x86_avx_256 / impl_x86_avx_256_util_float.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2014,2015,2017,2018,2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 #ifndef GMX_SIMD_IMPL_X86_AVX_256_UTIL_FLOAT_H
37 #define GMX_SIMD_IMPL_X86_AVX_256_UTIL_FLOAT_H
38
39 #include "config.h"
40
41 #include <cassert>
42 #include <cstddef>
43 #include <cstdint>
44
45 #include <immintrin.h>
46
47 #include "gromacs/utility/basedefinitions.h"
48
49 #include "impl_x86_avx_256_simd_float.h"
50
51 namespace gmx
52 {
53
54 /* This is an internal helper function used by decr3Hsimd(...).
55  */
56 static inline void gmx_simdcall decrHsimd(float* m, SimdFloat a)
57 {
58     assert(std::size_t(m) % 16 == 0);
59     __m128 asum = _mm_add_ps(_mm256_castps256_ps128(a.simdInternal_),
60                              _mm256_extractf128_ps(a.simdInternal_, 0x1));
61     _mm_store_ps(m, _mm_sub_ps(_mm_load_ps(m), asum));
62 }
63
64 /* This is an internal helper function used by the three functions storing,
65  * incrementing, or decrementing data. Do NOT use it outside this file.
66  *
67  * Input v0: [x0 x1 x2 x3 x4 x5 x6 x7]
68  * Input v1: [y0 y1 y2 y3 y4 y5 y6 y7]
69  * Input v2: [z0 z1 z2 z3 z4 z5 z6 z7]
70  * Input v3: Unused
71  *
72  * Output v0: [x0 y0 z0 -  x4 y4 z4 - ]
73  * Output v1: [x1 y1 z1 -  x5 y5 z5 - ]
74  * Output v2: [x2 y2 z2 -  x6 y6 z6 - ]
75  * Output v3: [x3 y3 z3 -  x7 y7 z7 - ]
76  *
77  * Here, - means undefined. Note that such values will not be zero!
78  */
79 static inline void gmx_simdcall avx256Transpose3By4InLanes(__m256* v0, __m256* v1, __m256* v2, __m256* v3)
80 {
81     __m256 t1 = _mm256_unpacklo_ps(*v0, *v1);
82     __m256 t2 = _mm256_unpackhi_ps(*v0, *v1);
83     *v0       = _mm256_shuffle_ps(t1, *v2, _MM_SHUFFLE(0, 0, 1, 0));
84     *v1       = _mm256_shuffle_ps(t1, *v2, _MM_SHUFFLE(0, 1, 3, 2));
85     *v3       = _mm256_shuffle_ps(t2, *v2, _MM_SHUFFLE(0, 3, 3, 2));
86     *v2       = _mm256_shuffle_ps(t2, *v2, _MM_SHUFFLE(0, 2, 1, 0));
87 }
88
89 template<int align>
90 static inline void gmx_simdcall gatherLoadTranspose(const float*       base,
91                                                     const std::int32_t offset[],
92                                                     SimdFloat*         v0,
93                                                     SimdFloat*         v1,
94                                                     SimdFloat*         v2,
95                                                     SimdFloat*         v3)
96 {
97     __m128 t1, t2, t3, t4, t5, t6, t7, t8;
98     __m256 tA, tB, tC, tD;
99
100     assert(std::size_t(offset) % 32 == 0);
101     assert(std::size_t(base) % 16 == 0);
102     assert(align % 4 == 0);
103
104     t1 = _mm_load_ps(base + align * offset[0]);
105     t2 = _mm_load_ps(base + align * offset[1]);
106     t3 = _mm_load_ps(base + align * offset[2]);
107     t4 = _mm_load_ps(base + align * offset[3]);
108     t5 = _mm_load_ps(base + align * offset[4]);
109     t6 = _mm_load_ps(base + align * offset[5]);
110     t7 = _mm_load_ps(base + align * offset[6]);
111     t8 = _mm_load_ps(base + align * offset[7]);
112
113     v0->simdInternal_ = _mm256_insertf128_ps(_mm256_castps128_ps256(t1), t5, 0x1);
114     v1->simdInternal_ = _mm256_insertf128_ps(_mm256_castps128_ps256(t2), t6, 0x1);
115     v2->simdInternal_ = _mm256_insertf128_ps(_mm256_castps128_ps256(t3), t7, 0x1);
116     v3->simdInternal_ = _mm256_insertf128_ps(_mm256_castps128_ps256(t4), t8, 0x1);
117
118     tA = _mm256_unpacklo_ps(v0->simdInternal_, v1->simdInternal_);
119     tB = _mm256_unpacklo_ps(v2->simdInternal_, v3->simdInternal_);
120     tC = _mm256_unpackhi_ps(v0->simdInternal_, v1->simdInternal_);
121     tD = _mm256_unpackhi_ps(v2->simdInternal_, v3->simdInternal_);
122
123     v0->simdInternal_ = _mm256_shuffle_ps(tA, tB, _MM_SHUFFLE(1, 0, 1, 0));
124     v1->simdInternal_ = _mm256_shuffle_ps(tA, tB, _MM_SHUFFLE(3, 2, 3, 2));
125     v2->simdInternal_ = _mm256_shuffle_ps(tC, tD, _MM_SHUFFLE(1, 0, 1, 0));
126     v3->simdInternal_ = _mm256_shuffle_ps(tC, tD, _MM_SHUFFLE(3, 2, 3, 2));
127 }
128
129 template<int align>
130 static inline void gmx_simdcall
131 gatherLoadTranspose(const float* base, const std::int32_t offset[], SimdFloat* v0, SimdFloat* v1)
132 {
133     __m128 t1, t2, t3, t4, t5, t6, t7, t8;
134     __m256 tA, tB, tC, tD;
135
136     assert(std::size_t(offset) % 32 == 0);
137     assert(std::size_t(base) % 8 == 0);
138     assert(align % 2 == 0);
139
140     t1 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[0]));
141     t2 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[1]));
142     t3 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[2]));
143     t4 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[3]));
144     t5 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[4]));
145     t6 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[5]));
146     t7 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[6]));
147     t8 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[7]));
148
149     tA = _mm256_insertf128_ps(_mm256_castps128_ps256(t1), t5, 0x1);
150     tB = _mm256_insertf128_ps(_mm256_castps128_ps256(t2), t6, 0x1);
151     tC = _mm256_insertf128_ps(_mm256_castps128_ps256(t3), t7, 0x1);
152     tD = _mm256_insertf128_ps(_mm256_castps128_ps256(t4), t8, 0x1);
153
154     tA                = _mm256_unpacklo_ps(tA, tC);
155     tB                = _mm256_unpacklo_ps(tB, tD);
156     v0->simdInternal_ = _mm256_unpacklo_ps(tA, tB);
157     v1->simdInternal_ = _mm256_unpackhi_ps(tA, tB);
158 }
159
160 static const int c_simdBestPairAlignmentFloat = 2;
161
162 // With the implementation below, thread-sanitizer can detect false positives.
163 // For loading a triplet, we load 4 floats and ignore the last. Another thread
164 // might write to this element, but that will not affect the result.
165 // On AVX2 we can use a gather intrinsic instead.
166 template<int align>
167 static inline void gmx_simdcall gatherLoadUTranspose(const float*       base,
168                                                      const std::int32_t offset[],
169                                                      SimdFloat*         v0,
170                                                      SimdFloat*         v1,
171                                                      SimdFloat*         v2)
172 {
173     __m256 t1, t2, t3, t4, t5, t6, t7, t8;
174
175     assert(std::size_t(offset) % 32 == 0);
176
177     if (align % 4 == 0)
178     {
179         // we can use aligned loads since base should also be aligned in this case
180         assert(std::size_t(base) % 16 == 0);
181         t1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(base + align * offset[0])),
182                                   _mm_load_ps(base + align * offset[4]),
183                                   0x1);
184         t2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(base + align * offset[1])),
185                                   _mm_load_ps(base + align * offset[5]),
186                                   0x1);
187         t3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(base + align * offset[2])),
188                                   _mm_load_ps(base + align * offset[6]),
189                                   0x1);
190         t4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(base + align * offset[3])),
191                                   _mm_load_ps(base + align * offset[7]),
192                                   0x1);
193     }
194     else
195     {
196         // Use unaligned loads
197         t1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(base + align * offset[0])),
198                                   _mm_loadu_ps(base + align * offset[4]),
199                                   0x1);
200         t2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(base + align * offset[1])),
201                                   _mm_loadu_ps(base + align * offset[5]),
202                                   0x1);
203         t3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(base + align * offset[2])),
204                                   _mm_loadu_ps(base + align * offset[6]),
205                                   0x1);
206         t4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(base + align * offset[3])),
207                                   _mm_loadu_ps(base + align * offset[7]),
208                                   0x1);
209     }
210
211     t5                = _mm256_unpacklo_ps(t1, t2);
212     t6                = _mm256_unpacklo_ps(t3, t4);
213     t7                = _mm256_unpackhi_ps(t1, t2);
214     t8                = _mm256_unpackhi_ps(t3, t4);
215     v0->simdInternal_ = _mm256_shuffle_ps(t5, t6, _MM_SHUFFLE(1, 0, 1, 0));
216     v1->simdInternal_ = _mm256_shuffle_ps(t5, t6, _MM_SHUFFLE(3, 2, 3, 2));
217     v2->simdInternal_ = _mm256_shuffle_ps(t7, t8, _MM_SHUFFLE(1, 0, 1, 0));
218 }
219
220 template<int align>
221 static inline void gmx_simdcall
222 transposeScatterStoreU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
223 {
224     __m256  tv3;
225     __m128i mask = _mm_set_epi32(0, -1, -1, -1);
226
227     assert(std::size_t(offset) % 32 == 0);
228
229     avx256Transpose3By4InLanes(&v0.simdInternal_, &v1.simdInternal_, &v2.simdInternal_, &tv3);
230     _mm_maskstore_ps(base + align * offset[0], mask, _mm256_castps256_ps128(v0.simdInternal_));
231     _mm_maskstore_ps(base + align * offset[1], mask, _mm256_castps256_ps128(v1.simdInternal_));
232     _mm_maskstore_ps(base + align * offset[2], mask, _mm256_castps256_ps128(v2.simdInternal_));
233     _mm_maskstore_ps(base + align * offset[3], mask, _mm256_castps256_ps128(tv3));
234     _mm_maskstore_ps(base + align * offset[4], mask, _mm256_extractf128_ps(v0.simdInternal_, 0x1));
235     _mm_maskstore_ps(base + align * offset[5], mask, _mm256_extractf128_ps(v1.simdInternal_, 0x1));
236     _mm_maskstore_ps(base + align * offset[6], mask, _mm256_extractf128_ps(v2.simdInternal_, 0x1));
237     _mm_maskstore_ps(base + align * offset[7], mask, _mm256_extractf128_ps(tv3, 0x1));
238 }
239
240 template<int align>
241 static inline void gmx_simdcall
242 transposeScatterIncrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
243 {
244     __m256 t1, t2, t3, t4, t5, t6, t7, t8, t9, t10;
245     __m128 tA, tB, tC, tD, tE, tF, tG, tH, tX;
246
247     if (align < 4)
248     {
249         t5  = _mm256_unpacklo_ps(v1.simdInternal_, v2.simdInternal_);
250         t6  = _mm256_unpackhi_ps(v1.simdInternal_, v2.simdInternal_);
251         t7  = _mm256_shuffle_ps(v0.simdInternal_, t5, _MM_SHUFFLE(1, 0, 0, 0));
252         t8  = _mm256_shuffle_ps(v0.simdInternal_, t5, _MM_SHUFFLE(3, 2, 0, 1));
253         t9  = _mm256_shuffle_ps(v0.simdInternal_, t6, _MM_SHUFFLE(1, 0, 0, 2));
254         t10 = _mm256_shuffle_ps(v0.simdInternal_, t6, _MM_SHUFFLE(3, 2, 0, 3));
255
256         tA = _mm256_castps256_ps128(t7);
257         tB = _mm256_castps256_ps128(t8);
258         tC = _mm256_castps256_ps128(t9);
259         tD = _mm256_castps256_ps128(t10);
260         tE = _mm256_extractf128_ps(t7, 0x1);
261         tF = _mm256_extractf128_ps(t8, 0x1);
262         tG = _mm256_extractf128_ps(t9, 0x1);
263         tH = _mm256_extractf128_ps(t10, 0x1);
264
265         tX = _mm_load_ss(base + align * offset[0]);
266         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[0] + 1));
267         tX = _mm_add_ps(tX, tA);
268         _mm_store_ss(base + align * offset[0], tX);
269         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[0] + 1), tX);
270
271         tX = _mm_load_ss(base + align * offset[1]);
272         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[1] + 1));
273         tX = _mm_add_ps(tX, tB);
274         _mm_store_ss(base + align * offset[1], tX);
275         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[1] + 1), tX);
276
277         tX = _mm_load_ss(base + align * offset[2]);
278         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[2] + 1));
279         tX = _mm_add_ps(tX, tC);
280         _mm_store_ss(base + align * offset[2], tX);
281         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[2] + 1), tX);
282
283         tX = _mm_load_ss(base + align * offset[3]);
284         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[3] + 1));
285         tX = _mm_add_ps(tX, tD);
286         _mm_store_ss(base + align * offset[3], tX);
287         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[3] + 1), tX);
288
289         tX = _mm_load_ss(base + align * offset[4]);
290         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[4] + 1));
291         tX = _mm_add_ps(tX, tE);
292         _mm_store_ss(base + align * offset[4], tX);
293         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[4] + 1), tX);
294
295         tX = _mm_load_ss(base + align * offset[5]);
296         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[5] + 1));
297         tX = _mm_add_ps(tX, tF);
298         _mm_store_ss(base + align * offset[5], tX);
299         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[5] + 1), tX);
300
301         tX = _mm_load_ss(base + align * offset[6]);
302         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[6] + 1));
303         tX = _mm_add_ps(tX, tG);
304         _mm_store_ss(base + align * offset[6], tX);
305         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[6] + 1), tX);
306
307         tX = _mm_load_ss(base + align * offset[7]);
308         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[7] + 1));
309         tX = _mm_add_ps(tX, tH);
310         _mm_store_ss(base + align * offset[7], tX);
311         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[7] + 1), tX);
312     }
313     else
314     {
315         // Extra elements means we can use full width-4 load/store operations
316         t1 = _mm256_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
317         t2 = _mm256_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
318         t3 = _mm256_unpacklo_ps(v1.simdInternal_, _mm256_setzero_ps());
319         t4 = _mm256_unpackhi_ps(v1.simdInternal_, _mm256_setzero_ps());
320         t5 = _mm256_unpacklo_ps(t1, t3); // x0 y0 z0  0 | x4 y4 z4 0
321         t6 = _mm256_unpackhi_ps(t1, t3); // x1 y1 z1  0 | x5 y5 z5 0
322         t7 = _mm256_unpacklo_ps(t2, t4); // x2 y2 z2  0 | x6 y6 z6 0
323         t8 = _mm256_unpackhi_ps(t2, t4); // x3 y3 z3  0 | x7 y7 z7 0
324
325         if (align % 4 == 0)
326         {
327             // We can use aligned load & store
328             _mm_store_ps(base + align * offset[0],
329                          _mm_add_ps(_mm_load_ps(base + align * offset[0]), _mm256_castps256_ps128(t5)));
330             _mm_store_ps(base + align * offset[1],
331                          _mm_add_ps(_mm_load_ps(base + align * offset[1]), _mm256_castps256_ps128(t6)));
332             _mm_store_ps(base + align * offset[2],
333                          _mm_add_ps(_mm_load_ps(base + align * offset[2]), _mm256_castps256_ps128(t7)));
334             _mm_store_ps(base + align * offset[3],
335                          _mm_add_ps(_mm_load_ps(base + align * offset[3]), _mm256_castps256_ps128(t8)));
336             _mm_store_ps(base + align * offset[4],
337                          _mm_add_ps(_mm_load_ps(base + align * offset[4]), _mm256_extractf128_ps(t5, 0x1)));
338             _mm_store_ps(base + align * offset[5],
339                          _mm_add_ps(_mm_load_ps(base + align * offset[5]), _mm256_extractf128_ps(t6, 0x1)));
340             _mm_store_ps(base + align * offset[6],
341                          _mm_add_ps(_mm_load_ps(base + align * offset[6]), _mm256_extractf128_ps(t7, 0x1)));
342             _mm_store_ps(base + align * offset[7],
343                          _mm_add_ps(_mm_load_ps(base + align * offset[7]), _mm256_extractf128_ps(t8, 0x1)));
344         }
345         else
346         {
347             // alignment >=5, but not a multiple of 4
348             _mm_storeu_ps(base + align * offset[0],
349                           _mm_add_ps(_mm_loadu_ps(base + align * offset[0]), _mm256_castps256_ps128(t5)));
350             _mm_storeu_ps(base + align * offset[1],
351                           _mm_add_ps(_mm_loadu_ps(base + align * offset[1]), _mm256_castps256_ps128(t6)));
352             _mm_storeu_ps(base + align * offset[2],
353                           _mm_add_ps(_mm_loadu_ps(base + align * offset[2]), _mm256_castps256_ps128(t7)));
354             _mm_storeu_ps(base + align * offset[3],
355                           _mm_add_ps(_mm_loadu_ps(base + align * offset[3]), _mm256_castps256_ps128(t8)));
356             _mm_storeu_ps(
357                     base + align * offset[4],
358                     _mm_add_ps(_mm_loadu_ps(base + align * offset[4]), _mm256_extractf128_ps(t5, 0x1)));
359             _mm_storeu_ps(
360                     base + align * offset[5],
361                     _mm_add_ps(_mm_loadu_ps(base + align * offset[5]), _mm256_extractf128_ps(t6, 0x1)));
362             _mm_storeu_ps(
363                     base + align * offset[6],
364                     _mm_add_ps(_mm_loadu_ps(base + align * offset[6]), _mm256_extractf128_ps(t7, 0x1)));
365             _mm_storeu_ps(
366                     base + align * offset[7],
367                     _mm_add_ps(_mm_loadu_ps(base + align * offset[7]), _mm256_extractf128_ps(t8, 0x1)));
368         }
369     }
370 }
371
372 template<int align>
373 static inline void gmx_simdcall
374 transposeScatterDecrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
375 {
376     __m256 t1, t2, t3, t4, t5, t6, t7, t8, t9, t10;
377     __m128 tA, tB, tC, tD, tE, tF, tG, tH, tX;
378
379     if (align < 4)
380     {
381         t5  = _mm256_unpacklo_ps(v1.simdInternal_, v2.simdInternal_);
382         t6  = _mm256_unpackhi_ps(v1.simdInternal_, v2.simdInternal_);
383         t7  = _mm256_shuffle_ps(v0.simdInternal_, t5, _MM_SHUFFLE(1, 0, 0, 0));
384         t8  = _mm256_shuffle_ps(v0.simdInternal_, t5, _MM_SHUFFLE(3, 2, 0, 1));
385         t9  = _mm256_shuffle_ps(v0.simdInternal_, t6, _MM_SHUFFLE(1, 0, 0, 2));
386         t10 = _mm256_shuffle_ps(v0.simdInternal_, t6, _MM_SHUFFLE(3, 2, 0, 3));
387
388         tA = _mm256_castps256_ps128(t7);
389         tB = _mm256_castps256_ps128(t8);
390         tC = _mm256_castps256_ps128(t9);
391         tD = _mm256_castps256_ps128(t10);
392         tE = _mm256_extractf128_ps(t7, 0x1);
393         tF = _mm256_extractf128_ps(t8, 0x1);
394         tG = _mm256_extractf128_ps(t9, 0x1);
395         tH = _mm256_extractf128_ps(t10, 0x1);
396
397         tX = _mm_load_ss(base + align * offset[0]);
398         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[0] + 1));
399         tX = _mm_sub_ps(tX, tA);
400         _mm_store_ss(base + align * offset[0], tX);
401         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[0] + 1), tX);
402
403         tX = _mm_load_ss(base + align * offset[1]);
404         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[1] + 1));
405         tX = _mm_sub_ps(tX, tB);
406         _mm_store_ss(base + align * offset[1], tX);
407         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[1] + 1), tX);
408
409         tX = _mm_load_ss(base + align * offset[2]);
410         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[2] + 1));
411         tX = _mm_sub_ps(tX, tC);
412         _mm_store_ss(base + align * offset[2], tX);
413         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[2] + 1), tX);
414
415         tX = _mm_load_ss(base + align * offset[3]);
416         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[3] + 1));
417         tX = _mm_sub_ps(tX, tD);
418         _mm_store_ss(base + align * offset[3], tX);
419         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[3] + 1), tX);
420
421         tX = _mm_load_ss(base + align * offset[4]);
422         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[4] + 1));
423         tX = _mm_sub_ps(tX, tE);
424         _mm_store_ss(base + align * offset[4], tX);
425         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[4] + 1), tX);
426
427         tX = _mm_load_ss(base + align * offset[5]);
428         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[5] + 1));
429         tX = _mm_sub_ps(tX, tF);
430         _mm_store_ss(base + align * offset[5], tX);
431         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[5] + 1), tX);
432
433         tX = _mm_load_ss(base + align * offset[6]);
434         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[6] + 1));
435         tX = _mm_sub_ps(tX, tG);
436         _mm_store_ss(base + align * offset[6], tX);
437         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[6] + 1), tX);
438
439         tX = _mm_load_ss(base + align * offset[7]);
440         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[7] + 1));
441         tX = _mm_sub_ps(tX, tH);
442         _mm_store_ss(base + align * offset[7], tX);
443         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[7] + 1), tX);
444     }
445     else
446     {
447         // Extra elements means we can use full width-4 load/store operations
448         t1 = _mm256_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
449         t2 = _mm256_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
450         t3 = _mm256_unpacklo_ps(v1.simdInternal_, _mm256_setzero_ps());
451         t4 = _mm256_unpackhi_ps(v1.simdInternal_, _mm256_setzero_ps());
452         t5 = _mm256_unpacklo_ps(t1, t3); // x0 y0 z0  0 | x4 y4 z4 0
453         t6 = _mm256_unpackhi_ps(t1, t3); // x1 y1 z1  0 | x5 y5 z5 0
454         t7 = _mm256_unpacklo_ps(t2, t4); // x2 y2 z2  0 | x6 y6 z6 0
455         t8 = _mm256_unpackhi_ps(t2, t4); // x3 y3 z3  0 | x7 y7 z7 0
456
457         if (align % 4 == 0)
458         {
459             // We can use aligned load & store
460             _mm_store_ps(base + align * offset[0],
461                          _mm_sub_ps(_mm_load_ps(base + align * offset[0]), _mm256_castps256_ps128(t5)));
462             _mm_store_ps(base + align * offset[1],
463                          _mm_sub_ps(_mm_load_ps(base + align * offset[1]), _mm256_castps256_ps128(t6)));
464             _mm_store_ps(base + align * offset[2],
465                          _mm_sub_ps(_mm_load_ps(base + align * offset[2]), _mm256_castps256_ps128(t7)));
466             _mm_store_ps(base + align * offset[3],
467                          _mm_sub_ps(_mm_load_ps(base + align * offset[3]), _mm256_castps256_ps128(t8)));
468             _mm_store_ps(base + align * offset[4],
469                          _mm_sub_ps(_mm_load_ps(base + align * offset[4]), _mm256_extractf128_ps(t5, 0x1)));
470             _mm_store_ps(base + align * offset[5],
471                          _mm_sub_ps(_mm_load_ps(base + align * offset[5]), _mm256_extractf128_ps(t6, 0x1)));
472             _mm_store_ps(base + align * offset[6],
473                          _mm_sub_ps(_mm_load_ps(base + align * offset[6]), _mm256_extractf128_ps(t7, 0x1)));
474             _mm_store_ps(base + align * offset[7],
475                          _mm_sub_ps(_mm_load_ps(base + align * offset[7]), _mm256_extractf128_ps(t8, 0x1)));
476         }
477         else
478         {
479             // alignment >=5, but not a multiple of 4
480             _mm_storeu_ps(base + align * offset[0],
481                           _mm_sub_ps(_mm_loadu_ps(base + align * offset[0]), _mm256_castps256_ps128(t5)));
482             _mm_storeu_ps(base + align * offset[1],
483                           _mm_sub_ps(_mm_loadu_ps(base + align * offset[1]), _mm256_castps256_ps128(t6)));
484             _mm_storeu_ps(base + align * offset[2],
485                           _mm_sub_ps(_mm_loadu_ps(base + align * offset[2]), _mm256_castps256_ps128(t7)));
486             _mm_storeu_ps(base + align * offset[3],
487                           _mm_sub_ps(_mm_loadu_ps(base + align * offset[3]), _mm256_castps256_ps128(t8)));
488             _mm_storeu_ps(
489                     base + align * offset[4],
490                     _mm_sub_ps(_mm_loadu_ps(base + align * offset[4]), _mm256_extractf128_ps(t5, 0x1)));
491             _mm_storeu_ps(
492                     base + align * offset[5],
493                     _mm_sub_ps(_mm_loadu_ps(base + align * offset[5]), _mm256_extractf128_ps(t6, 0x1)));
494             _mm_storeu_ps(
495                     base + align * offset[6],
496                     _mm_sub_ps(_mm_loadu_ps(base + align * offset[6]), _mm256_extractf128_ps(t7, 0x1)));
497             _mm_storeu_ps(
498                     base + align * offset[7],
499                     _mm_sub_ps(_mm_loadu_ps(base + align * offset[7]), _mm256_extractf128_ps(t8, 0x1)));
500         }
501     }
502 }
503
504 static inline void gmx_simdcall expandScalarsToTriplets(SimdFloat  scalar,
505                                                         SimdFloat* triplets0,
506                                                         SimdFloat* triplets1,
507                                                         SimdFloat* triplets2)
508 {
509     __m256 t0 = _mm256_permute2f128_ps(scalar.simdInternal_, scalar.simdInternal_, 0x21);
510     __m256 t1 = _mm256_permute_ps(scalar.simdInternal_, _MM_SHUFFLE(1, 0, 0, 0));
511     __m256 t2 = _mm256_permute_ps(t0, _MM_SHUFFLE(2, 2, 1, 1));
512     __m256 t3 = _mm256_permute_ps(scalar.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
513     triplets0->simdInternal_ = _mm256_blend_ps(t1, t2, 0xF0);
514     triplets1->simdInternal_ = _mm256_blend_ps(t3, t1, 0xF0);
515     triplets2->simdInternal_ = _mm256_blend_ps(t2, t3, 0xF0);
516 }
517
518 template<int align>
519 static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const float* base,
520                                                              SimdFInt32   simdoffset,
521                                                              SimdFloat*   v0,
522                                                              SimdFloat*   v1,
523                                                              SimdFloat*   v2,
524                                                              SimdFloat*   v3)
525 {
526     alignas(GMX_SIMD_ALIGNMENT) std::int32_t offset[GMX_SIMD_FLOAT_WIDTH];
527     _mm256_store_si256(reinterpret_cast<__m256i*>(offset), simdoffset.simdInternal_);
528     gatherLoadTranspose<align>(base, offset, v0, v1, v2, v3);
529 }
530
531 template<int align>
532 static inline void gmx_simdcall
533 gatherLoadBySimdIntTranspose(const float* base, SimdFInt32 simdoffset, SimdFloat* v0, SimdFloat* v1)
534 {
535     alignas(GMX_SIMD_ALIGNMENT) std::int32_t offset[GMX_SIMD_FLOAT_WIDTH];
536     _mm256_store_si256(reinterpret_cast<__m256i*>(offset), simdoffset.simdInternal_);
537     gatherLoadTranspose<align>(base, offset, v0, v1);
538 }
539
540
541 template<int align>
542 static inline void gmx_simdcall
543 gatherLoadUBySimdIntTranspose(const float* base, SimdFInt32 simdoffset, SimdFloat* v0, SimdFloat* v1)
544 {
545     __m128 t1, t2, t3, t4, t5, t6, t7, t8;
546     __m256 tA, tB, tC, tD;
547
548     alignas(GMX_SIMD_ALIGNMENT) std::int32_t offset[GMX_SIMD_FLOAT_WIDTH];
549     _mm256_store_si256(reinterpret_cast<__m256i*>(offset), simdoffset.simdInternal_);
550
551     t1 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[0]));
552     t2 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[1]));
553     t3 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[2]));
554     t4 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[3]));
555     t5 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[4]));
556     t6 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[5]));
557     t7 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[6]));
558     t8 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[7]));
559
560     tA = _mm256_insertf128_ps(_mm256_castps128_ps256(t1), t5, 0x1);
561     tB = _mm256_insertf128_ps(_mm256_castps128_ps256(t2), t6, 0x1);
562     tC = _mm256_insertf128_ps(_mm256_castps128_ps256(t3), t7, 0x1);
563     tD = _mm256_insertf128_ps(_mm256_castps128_ps256(t4), t8, 0x1);
564
565     tA                = _mm256_unpacklo_ps(tA, tC);
566     tB                = _mm256_unpacklo_ps(tB, tD);
567     v0->simdInternal_ = _mm256_unpacklo_ps(tA, tB);
568     v1->simdInternal_ = _mm256_unpackhi_ps(tA, tB);
569 }
570
571 static inline float gmx_simdcall reduceIncr4ReturnSum(float* m, SimdFloat v0, SimdFloat v1, SimdFloat v2, SimdFloat v3)
572 {
573     __m128 t0, t2;
574
575     assert(std::size_t(m) % 16 == 0);
576
577     v0.simdInternal_ = _mm256_hadd_ps(v0.simdInternal_, v1.simdInternal_);
578     v2.simdInternal_ = _mm256_hadd_ps(v2.simdInternal_, v3.simdInternal_);
579     v0.simdInternal_ = _mm256_hadd_ps(v0.simdInternal_, v2.simdInternal_);
580     t0               = _mm_add_ps(_mm256_castps256_ps128(v0.simdInternal_), _mm256_extractf128_ps(v0.simdInternal_, 0x1));
581
582     t2 = _mm_add_ps(t0, _mm_load_ps(m));
583     _mm_store_ps(m, t2);
584
585     t0 = _mm_add_ps(t0, _mm_permute_ps(t0, _MM_SHUFFLE(1, 0, 3, 2)));
586     t0 = _mm_add_ss(t0, _mm_permute_ps(t0, _MM_SHUFFLE(0, 3, 2, 1)));
587     return *reinterpret_cast<float*>(&t0);
588 }
589
590
591 /*************************************
592  * Half-simd-width utility functions *
593  *************************************/
594 static inline SimdFloat gmx_simdcall loadDualHsimd(const float* m0, const float* m1)
595 {
596     assert(std::size_t(m0) % 16 == 0);
597     assert(std::size_t(m1) % 16 == 0);
598
599     return { _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(m0)), _mm_load_ps(m1), 0x1) };
600 }
601
602 static inline SimdFloat gmx_simdcall loadDuplicateHsimd(const float* m)
603 {
604     assert(std::size_t(m) % 16 == 0);
605
606     return { _mm256_broadcast_ps(reinterpret_cast<const __m128*>(m)) };
607 }
608
609 static inline SimdFloat gmx_simdcall loadU1DualHsimd(const float* m)
610 {
611     __m128 t0, t1;
612     t0 = _mm_broadcast_ss(m);
613     t1 = _mm_broadcast_ss(m + 1);
614     return { _mm256_insertf128_ps(_mm256_castps128_ps256(t0), t1, 0x1) };
615 }
616
617
618 static inline void gmx_simdcall storeDualHsimd(float* m0, float* m1, SimdFloat a)
619 {
620     assert(std::size_t(m0) % 16 == 0);
621     assert(std::size_t(m1) % 16 == 0);
622     _mm_store_ps(m0, _mm256_castps256_ps128(a.simdInternal_));
623     _mm_store_ps(m1, _mm256_extractf128_ps(a.simdInternal_, 0x1));
624 }
625
626 static inline void gmx_simdcall incrDualHsimd(float* m0, float* m1, SimdFloat a)
627 {
628     assert(std::size_t(m0) % 16 == 0);
629     assert(std::size_t(m1) % 16 == 0);
630     _mm_store_ps(m0, _mm_add_ps(_mm256_castps256_ps128(a.simdInternal_), _mm_load_ps(m0)));
631     _mm_store_ps(m1, _mm_add_ps(_mm256_extractf128_ps(a.simdInternal_, 0x1), _mm_load_ps(m1)));
632 }
633
634 static inline void gmx_simdcall decr3Hsimd(float* m, SimdFloat a0, SimdFloat a1, SimdFloat a2)
635 {
636     assert(std::size_t(m) % 16 == 0);
637     decrHsimd(m, a0);
638     decrHsimd(m + GMX_SIMD_FLOAT_WIDTH / 2, a1);
639     decrHsimd(m + GMX_SIMD_FLOAT_WIDTH, a2);
640 }
641
642
643 template<int align>
644 static inline void gmx_simdcall gatherLoadTransposeHsimd(const float*       base0,
645                                                          const float*       base1,
646                                                          const std::int32_t offset[],
647                                                          SimdFloat*         v0,
648                                                          SimdFloat*         v1)
649 {
650     __m128 t0, t1, t2, t3, t4, t5, t6, t7;
651     __m256 tA, tB, tC, tD;
652
653     assert(std::size_t(offset) % 16 == 0);
654     assert(std::size_t(base0) % 8 == 0);
655     assert(std::size_t(base1) % 8 == 0);
656     assert(align % 2 == 0);
657
658     t0 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base0 + align * offset[0]));
659     t1 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base0 + align * offset[1]));
660     t2 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base0 + align * offset[2]));
661     t3 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base0 + align * offset[3]));
662     t4 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base1 + align * offset[0]));
663     t5 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base1 + align * offset[1]));
664     t6 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base1 + align * offset[2]));
665     t7 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base1 + align * offset[3]));
666
667     tA = _mm256_insertf128_ps(_mm256_castps128_ps256(t0), t4, 0x1);
668     tB = _mm256_insertf128_ps(_mm256_castps128_ps256(t1), t5, 0x1);
669     tC = _mm256_insertf128_ps(_mm256_castps128_ps256(t2), t6, 0x1);
670     tD = _mm256_insertf128_ps(_mm256_castps128_ps256(t3), t7, 0x1);
671
672     tA                = _mm256_unpacklo_ps(tA, tC);
673     tB                = _mm256_unpacklo_ps(tB, tD);
674     v0->simdInternal_ = _mm256_unpacklo_ps(tA, tB);
675     v1->simdInternal_ = _mm256_unpackhi_ps(tA, tB);
676 }
677
678
679 static inline float gmx_simdcall reduceIncr4ReturnSumHsimd(float* m, SimdFloat v0, SimdFloat v1)
680 {
681     __m128 t0, t1;
682
683     v0.simdInternal_ = _mm256_hadd_ps(v0.simdInternal_, v1.simdInternal_);
684     t0               = _mm256_extractf128_ps(v0.simdInternal_, 0x1);
685     t0               = _mm_hadd_ps(_mm256_castps256_ps128(v0.simdInternal_), t0);
686     t0               = _mm_permute_ps(t0, _MM_SHUFFLE(3, 1, 2, 0));
687
688     assert(std::size_t(m) % 16 == 0);
689
690     t1 = _mm_add_ps(t0, _mm_load_ps(m));
691     _mm_store_ps(m, t1);
692
693     t0 = _mm_add_ps(t0, _mm_permute_ps(t0, _MM_SHUFFLE(1, 0, 3, 2)));
694     t0 = _mm_add_ss(t0, _mm_permute_ps(t0, _MM_SHUFFLE(0, 3, 2, 1)));
695     return *reinterpret_cast<float*>(&t0);
696 }
697
698 static inline SimdFloat gmx_simdcall loadU4NOffset(const float* m, int offset)
699 {
700     return { _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(m)), _mm_loadu_ps(m + offset), 0x1) };
701 }
702
703
704 } // namespace gmx
705
706 #endif // GMX_SIMD_IMPL_X86_AVX_256_UTIL_FLOAT_H