0e0957a8130962aa7f21cea9155a91213d9fe058
[alexxy/gromacs.git] / src / gromacs / simd / impl_x86_avx_256 / impl_x86_avx_256_util_float.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2014,2015,2017,2018,2019,2020, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 #ifndef GMX_SIMD_IMPL_X86_AVX_256_UTIL_FLOAT_H
37 #define GMX_SIMD_IMPL_X86_AVX_256_UTIL_FLOAT_H
38
39 #include "config.h"
40
41 #include <cassert>
42 #include <cstddef>
43 #include <cstdint>
44
45 #include <immintrin.h>
46
47 #include "gromacs/utility/basedefinitions.h"
48
49 #include "impl_x86_avx_256_simd_float.h"
50
51 namespace gmx
52 {
53
54 /* This is an internal helper function used by decr3Hsimd(...).
55  */
56 static inline void gmx_simdcall decrHsimd(float* m, SimdFloat a)
57 {
58     assert(std::size_t(m) % 16 == 0);
59     __m128 asum = _mm_add_ps(_mm256_castps256_ps128(a.simdInternal_),
60                              _mm256_extractf128_ps(a.simdInternal_, 0x1));
61     _mm_store_ps(m, _mm_sub_ps(_mm_load_ps(m), asum));
62 }
63
64 /* This is an internal helper function used by the three functions storing,
65  * incrementing, or decrementing data. Do NOT use it outside this file.
66  *
67  * Input v0: [x0 x1 x2 x3 x4 x5 x6 x7]
68  * Input v1: [y0 y1 y2 y3 y4 y5 y6 y7]
69  * Input v2: [z0 z1 z2 z3 z4 z5 z6 z7]
70  * Input v3: Unused
71  *
72  * Output v0: [x0 y0 z0 -  x4 y4 z4 - ]
73  * Output v1: [x1 y1 z1 -  x5 y5 z5 - ]
74  * Output v2: [x2 y2 z2 -  x6 y6 z6 - ]
75  * Output v3: [x3 y3 z3 -  x7 y7 z7 - ]
76  *
77  * Here, - means undefined. Note that such values will not be zero!
78  */
79 static inline void gmx_simdcall avx256Transpose3By4InLanes(__m256* v0, __m256* v1, __m256* v2, __m256* v3)
80 {
81     __m256 t1 = _mm256_unpacklo_ps(*v0, *v1);
82     __m256 t2 = _mm256_unpackhi_ps(*v0, *v1);
83     *v0       = _mm256_shuffle_ps(t1, *v2, _MM_SHUFFLE(0, 0, 1, 0));
84     *v1       = _mm256_shuffle_ps(t1, *v2, _MM_SHUFFLE(0, 1, 3, 2));
85     *v3       = _mm256_shuffle_ps(t2, *v2, _MM_SHUFFLE(0, 3, 3, 2));
86     *v2       = _mm256_shuffle_ps(t2, *v2, _MM_SHUFFLE(0, 2, 1, 0));
87 }
88
89 template<int align>
90 static inline void gmx_simdcall gatherLoadTranspose(const float*       base,
91                                                     const std::int32_t offset[],
92                                                     SimdFloat*         v0,
93                                                     SimdFloat*         v1,
94                                                     SimdFloat*         v2,
95                                                     SimdFloat*         v3)
96 {
97     __m128 t1, t2, t3, t4, t5, t6, t7, t8;
98     __m256 tA, tB, tC, tD;
99
100     assert(std::size_t(offset) % 32 == 0);
101     assert(std::size_t(base) % 16 == 0);
102     assert(align % 4 == 0);
103
104     t1 = _mm_load_ps(base + align * offset[0]);
105     t2 = _mm_load_ps(base + align * offset[1]);
106     t3 = _mm_load_ps(base + align * offset[2]);
107     t4 = _mm_load_ps(base + align * offset[3]);
108     t5 = _mm_load_ps(base + align * offset[4]);
109     t6 = _mm_load_ps(base + align * offset[5]);
110     t7 = _mm_load_ps(base + align * offset[6]);
111     t8 = _mm_load_ps(base + align * offset[7]);
112
113     v0->simdInternal_ = _mm256_insertf128_ps(_mm256_castps128_ps256(t1), t5, 0x1);
114     v1->simdInternal_ = _mm256_insertf128_ps(_mm256_castps128_ps256(t2), t6, 0x1);
115     v2->simdInternal_ = _mm256_insertf128_ps(_mm256_castps128_ps256(t3), t7, 0x1);
116     v3->simdInternal_ = _mm256_insertf128_ps(_mm256_castps128_ps256(t4), t8, 0x1);
117
118     tA = _mm256_unpacklo_ps(v0->simdInternal_, v1->simdInternal_);
119     tB = _mm256_unpacklo_ps(v2->simdInternal_, v3->simdInternal_);
120     tC = _mm256_unpackhi_ps(v0->simdInternal_, v1->simdInternal_);
121     tD = _mm256_unpackhi_ps(v2->simdInternal_, v3->simdInternal_);
122
123     v0->simdInternal_ = _mm256_shuffle_ps(tA, tB, _MM_SHUFFLE(1, 0, 1, 0));
124     v1->simdInternal_ = _mm256_shuffle_ps(tA, tB, _MM_SHUFFLE(3, 2, 3, 2));
125     v2->simdInternal_ = _mm256_shuffle_ps(tC, tD, _MM_SHUFFLE(1, 0, 1, 0));
126     v3->simdInternal_ = _mm256_shuffle_ps(tC, tD, _MM_SHUFFLE(3, 2, 3, 2));
127 }
128
129 template<int align>
130 static inline void gmx_simdcall
131                    gatherLoadTranspose(const float* base, const std::int32_t offset[], SimdFloat* v0, SimdFloat* v1)
132 {
133     __m128 t1, t2, t3, t4, t5, t6, t7, t8;
134     __m256 tA, tB, tC, tD;
135
136     assert(std::size_t(offset) % 32 == 0);
137     assert(std::size_t(base) % 8 == 0);
138     assert(align % 2 == 0);
139
140     t1 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[0]));
141     t2 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[1]));
142     t3 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[2]));
143     t4 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[3]));
144     t5 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[4]));
145     t6 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[5]));
146     t7 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[6]));
147     t8 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[7]));
148
149     tA = _mm256_insertf128_ps(_mm256_castps128_ps256(t1), t5, 0x1);
150     tB = _mm256_insertf128_ps(_mm256_castps128_ps256(t2), t6, 0x1);
151     tC = _mm256_insertf128_ps(_mm256_castps128_ps256(t3), t7, 0x1);
152     tD = _mm256_insertf128_ps(_mm256_castps128_ps256(t4), t8, 0x1);
153
154     tA                = _mm256_unpacklo_ps(tA, tC);
155     tB                = _mm256_unpacklo_ps(tB, tD);
156     v0->simdInternal_ = _mm256_unpacklo_ps(tA, tB);
157     v1->simdInternal_ = _mm256_unpackhi_ps(tA, tB);
158 }
159
160 static const int c_simdBestPairAlignmentFloat = 2;
161
162 // With the implementation below, thread-sanitizer can detect false positives.
163 // For loading a triplet, we load 4 floats and ignore the last. Another thread
164 // might write to this element, but that will not affect the result.
165 // On AVX2 we can use a gather intrinsic instead.
166 template<int align>
167 static inline void gmx_simdcall gatherLoadUTranspose(const float*       base,
168                                                      const std::int32_t offset[],
169                                                      SimdFloat*         v0,
170                                                      SimdFloat*         v1,
171                                                      SimdFloat*         v2)
172 {
173     __m256 t1, t2, t3, t4, t5, t6, t7, t8;
174
175     assert(std::size_t(offset) % 32 == 0);
176
177     if (align % 4 == 0)
178     {
179         // we can use aligned loads since base should also be aligned in this case
180         assert(std::size_t(base) % 16 == 0);
181         t1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(base + align * offset[0])),
182                                   _mm_load_ps(base + align * offset[4]), 0x1);
183         t2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(base + align * offset[1])),
184                                   _mm_load_ps(base + align * offset[5]), 0x1);
185         t3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(base + align * offset[2])),
186                                   _mm_load_ps(base + align * offset[6]), 0x1);
187         t4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(base + align * offset[3])),
188                                   _mm_load_ps(base + align * offset[7]), 0x1);
189     }
190     else
191     {
192         // Use unaligned loads
193         t1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(base + align * offset[0])),
194                                   _mm_loadu_ps(base + align * offset[4]), 0x1);
195         t2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(base + align * offset[1])),
196                                   _mm_loadu_ps(base + align * offset[5]), 0x1);
197         t3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(base + align * offset[2])),
198                                   _mm_loadu_ps(base + align * offset[6]), 0x1);
199         t4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(base + align * offset[3])),
200                                   _mm_loadu_ps(base + align * offset[7]), 0x1);
201     }
202
203     t5                = _mm256_unpacklo_ps(t1, t2);
204     t6                = _mm256_unpacklo_ps(t3, t4);
205     t7                = _mm256_unpackhi_ps(t1, t2);
206     t8                = _mm256_unpackhi_ps(t3, t4);
207     v0->simdInternal_ = _mm256_shuffle_ps(t5, t6, _MM_SHUFFLE(1, 0, 1, 0));
208     v1->simdInternal_ = _mm256_shuffle_ps(t5, t6, _MM_SHUFFLE(3, 2, 3, 2));
209     v2->simdInternal_ = _mm256_shuffle_ps(t7, t8, _MM_SHUFFLE(1, 0, 1, 0));
210 }
211
212 template<int align>
213 static inline void gmx_simdcall
214                    transposeScatterStoreU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
215 {
216     __m256  tv3;
217     __m128i mask = _mm_set_epi32(0, -1, -1, -1);
218
219     assert(std::size_t(offset) % 32 == 0);
220
221     avx256Transpose3By4InLanes(&v0.simdInternal_, &v1.simdInternal_, &v2.simdInternal_, &tv3);
222     _mm_maskstore_ps(base + align * offset[0], mask, _mm256_castps256_ps128(v0.simdInternal_));
223     _mm_maskstore_ps(base + align * offset[1], mask, _mm256_castps256_ps128(v1.simdInternal_));
224     _mm_maskstore_ps(base + align * offset[2], mask, _mm256_castps256_ps128(v2.simdInternal_));
225     _mm_maskstore_ps(base + align * offset[3], mask, _mm256_castps256_ps128(tv3));
226     _mm_maskstore_ps(base + align * offset[4], mask, _mm256_extractf128_ps(v0.simdInternal_, 0x1));
227     _mm_maskstore_ps(base + align * offset[5], mask, _mm256_extractf128_ps(v1.simdInternal_, 0x1));
228     _mm_maskstore_ps(base + align * offset[6], mask, _mm256_extractf128_ps(v2.simdInternal_, 0x1));
229     _mm_maskstore_ps(base + align * offset[7], mask, _mm256_extractf128_ps(tv3, 0x1));
230 }
231
232 template<int align>
233 static inline void gmx_simdcall
234                    transposeScatterIncrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
235 {
236     __m256 t1, t2, t3, t4, t5, t6, t7, t8, t9, t10;
237     __m128 tA, tB, tC, tD, tE, tF, tG, tH, tX;
238
239     if (align < 4)
240     {
241         t5  = _mm256_unpacklo_ps(v1.simdInternal_, v2.simdInternal_);
242         t6  = _mm256_unpackhi_ps(v1.simdInternal_, v2.simdInternal_);
243         t7  = _mm256_shuffle_ps(v0.simdInternal_, t5, _MM_SHUFFLE(1, 0, 0, 0));
244         t8  = _mm256_shuffle_ps(v0.simdInternal_, t5, _MM_SHUFFLE(3, 2, 0, 1));
245         t9  = _mm256_shuffle_ps(v0.simdInternal_, t6, _MM_SHUFFLE(1, 0, 0, 2));
246         t10 = _mm256_shuffle_ps(v0.simdInternal_, t6, _MM_SHUFFLE(3, 2, 0, 3));
247
248         tA = _mm256_castps256_ps128(t7);
249         tB = _mm256_castps256_ps128(t8);
250         tC = _mm256_castps256_ps128(t9);
251         tD = _mm256_castps256_ps128(t10);
252         tE = _mm256_extractf128_ps(t7, 0x1);
253         tF = _mm256_extractf128_ps(t8, 0x1);
254         tG = _mm256_extractf128_ps(t9, 0x1);
255         tH = _mm256_extractf128_ps(t10, 0x1);
256
257         tX = _mm_load_ss(base + align * offset[0]);
258         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[0] + 1));
259         tX = _mm_add_ps(tX, tA);
260         _mm_store_ss(base + align * offset[0], tX);
261         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[0] + 1), tX);
262
263         tX = _mm_load_ss(base + align * offset[1]);
264         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[1] + 1));
265         tX = _mm_add_ps(tX, tB);
266         _mm_store_ss(base + align * offset[1], tX);
267         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[1] + 1), tX);
268
269         tX = _mm_load_ss(base + align * offset[2]);
270         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[2] + 1));
271         tX = _mm_add_ps(tX, tC);
272         _mm_store_ss(base + align * offset[2], tX);
273         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[2] + 1), tX);
274
275         tX = _mm_load_ss(base + align * offset[3]);
276         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[3] + 1));
277         tX = _mm_add_ps(tX, tD);
278         _mm_store_ss(base + align * offset[3], tX);
279         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[3] + 1), tX);
280
281         tX = _mm_load_ss(base + align * offset[4]);
282         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[4] + 1));
283         tX = _mm_add_ps(tX, tE);
284         _mm_store_ss(base + align * offset[4], tX);
285         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[4] + 1), tX);
286
287         tX = _mm_load_ss(base + align * offset[5]);
288         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[5] + 1));
289         tX = _mm_add_ps(tX, tF);
290         _mm_store_ss(base + align * offset[5], tX);
291         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[5] + 1), tX);
292
293         tX = _mm_load_ss(base + align * offset[6]);
294         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[6] + 1));
295         tX = _mm_add_ps(tX, tG);
296         _mm_store_ss(base + align * offset[6], tX);
297         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[6] + 1), tX);
298
299         tX = _mm_load_ss(base + align * offset[7]);
300         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[7] + 1));
301         tX = _mm_add_ps(tX, tH);
302         _mm_store_ss(base + align * offset[7], tX);
303         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[7] + 1), tX);
304     }
305     else
306     {
307         // Extra elements means we can use full width-4 load/store operations
308         t1 = _mm256_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
309         t2 = _mm256_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
310         t3 = _mm256_unpacklo_ps(v1.simdInternal_, _mm256_setzero_ps());
311         t4 = _mm256_unpackhi_ps(v1.simdInternal_, _mm256_setzero_ps());
312         t5 = _mm256_unpacklo_ps(t1, t3); // x0 y0 z0  0 | x4 y4 z4 0
313         t6 = _mm256_unpackhi_ps(t1, t3); // x1 y1 z1  0 | x5 y5 z5 0
314         t7 = _mm256_unpacklo_ps(t2, t4); // x2 y2 z2  0 | x6 y6 z6 0
315         t8 = _mm256_unpackhi_ps(t2, t4); // x3 y3 z3  0 | x7 y7 z7 0
316
317         if (align % 4 == 0)
318         {
319             // We can use aligned load & store
320             _mm_store_ps(base + align * offset[0],
321                          _mm_add_ps(_mm_load_ps(base + align * offset[0]), _mm256_castps256_ps128(t5)));
322             _mm_store_ps(base + align * offset[1],
323                          _mm_add_ps(_mm_load_ps(base + align * offset[1]), _mm256_castps256_ps128(t6)));
324             _mm_store_ps(base + align * offset[2],
325                          _mm_add_ps(_mm_load_ps(base + align * offset[2]), _mm256_castps256_ps128(t7)));
326             _mm_store_ps(base + align * offset[3],
327                          _mm_add_ps(_mm_load_ps(base + align * offset[3]), _mm256_castps256_ps128(t8)));
328             _mm_store_ps(base + align * offset[4], _mm_add_ps(_mm_load_ps(base + align * offset[4]),
329                                                               _mm256_extractf128_ps(t5, 0x1)));
330             _mm_store_ps(base + align * offset[5], _mm_add_ps(_mm_load_ps(base + align * offset[5]),
331                                                               _mm256_extractf128_ps(t6, 0x1)));
332             _mm_store_ps(base + align * offset[6], _mm_add_ps(_mm_load_ps(base + align * offset[6]),
333                                                               _mm256_extractf128_ps(t7, 0x1)));
334             _mm_store_ps(base + align * offset[7], _mm_add_ps(_mm_load_ps(base + align * offset[7]),
335                                                               _mm256_extractf128_ps(t8, 0x1)));
336         }
337         else
338         {
339             // alignment >=5, but not a multiple of 4
340             _mm_storeu_ps(base + align * offset[0], _mm_add_ps(_mm_loadu_ps(base + align * offset[0]),
341                                                                _mm256_castps256_ps128(t5)));
342             _mm_storeu_ps(base + align * offset[1], _mm_add_ps(_mm_loadu_ps(base + align * offset[1]),
343                                                                _mm256_castps256_ps128(t6)));
344             _mm_storeu_ps(base + align * offset[2], _mm_add_ps(_mm_loadu_ps(base + align * offset[2]),
345                                                                _mm256_castps256_ps128(t7)));
346             _mm_storeu_ps(base + align * offset[3], _mm_add_ps(_mm_loadu_ps(base + align * offset[3]),
347                                                                _mm256_castps256_ps128(t8)));
348             _mm_storeu_ps(base + align * offset[4], _mm_add_ps(_mm_loadu_ps(base + align * offset[4]),
349                                                                _mm256_extractf128_ps(t5, 0x1)));
350             _mm_storeu_ps(base + align * offset[5], _mm_add_ps(_mm_loadu_ps(base + align * offset[5]),
351                                                                _mm256_extractf128_ps(t6, 0x1)));
352             _mm_storeu_ps(base + align * offset[6], _mm_add_ps(_mm_loadu_ps(base + align * offset[6]),
353                                                                _mm256_extractf128_ps(t7, 0x1)));
354             _mm_storeu_ps(base + align * offset[7], _mm_add_ps(_mm_loadu_ps(base + align * offset[7]),
355                                                                _mm256_extractf128_ps(t8, 0x1)));
356         }
357     }
358 }
359
360 template<int align>
361 static inline void gmx_simdcall
362                    transposeScatterDecrU(float* base, const std::int32_t offset[], SimdFloat v0, SimdFloat v1, SimdFloat v2)
363 {
364     __m256 t1, t2, t3, t4, t5, t6, t7, t8, t9, t10;
365     __m128 tA, tB, tC, tD, tE, tF, tG, tH, tX;
366
367     if (align < 4)
368     {
369         t5  = _mm256_unpacklo_ps(v1.simdInternal_, v2.simdInternal_);
370         t6  = _mm256_unpackhi_ps(v1.simdInternal_, v2.simdInternal_);
371         t7  = _mm256_shuffle_ps(v0.simdInternal_, t5, _MM_SHUFFLE(1, 0, 0, 0));
372         t8  = _mm256_shuffle_ps(v0.simdInternal_, t5, _MM_SHUFFLE(3, 2, 0, 1));
373         t9  = _mm256_shuffle_ps(v0.simdInternal_, t6, _MM_SHUFFLE(1, 0, 0, 2));
374         t10 = _mm256_shuffle_ps(v0.simdInternal_, t6, _MM_SHUFFLE(3, 2, 0, 3));
375
376         tA = _mm256_castps256_ps128(t7);
377         tB = _mm256_castps256_ps128(t8);
378         tC = _mm256_castps256_ps128(t9);
379         tD = _mm256_castps256_ps128(t10);
380         tE = _mm256_extractf128_ps(t7, 0x1);
381         tF = _mm256_extractf128_ps(t8, 0x1);
382         tG = _mm256_extractf128_ps(t9, 0x1);
383         tH = _mm256_extractf128_ps(t10, 0x1);
384
385         tX = _mm_load_ss(base + align * offset[0]);
386         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[0] + 1));
387         tX = _mm_sub_ps(tX, tA);
388         _mm_store_ss(base + align * offset[0], tX);
389         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[0] + 1), tX);
390
391         tX = _mm_load_ss(base + align * offset[1]);
392         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[1] + 1));
393         tX = _mm_sub_ps(tX, tB);
394         _mm_store_ss(base + align * offset[1], tX);
395         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[1] + 1), tX);
396
397         tX = _mm_load_ss(base + align * offset[2]);
398         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[2] + 1));
399         tX = _mm_sub_ps(tX, tC);
400         _mm_store_ss(base + align * offset[2], tX);
401         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[2] + 1), tX);
402
403         tX = _mm_load_ss(base + align * offset[3]);
404         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[3] + 1));
405         tX = _mm_sub_ps(tX, tD);
406         _mm_store_ss(base + align * offset[3], tX);
407         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[3] + 1), tX);
408
409         tX = _mm_load_ss(base + align * offset[4]);
410         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[4] + 1));
411         tX = _mm_sub_ps(tX, tE);
412         _mm_store_ss(base + align * offset[4], tX);
413         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[4] + 1), tX);
414
415         tX = _mm_load_ss(base + align * offset[5]);
416         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[5] + 1));
417         tX = _mm_sub_ps(tX, tF);
418         _mm_store_ss(base + align * offset[5], tX);
419         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[5] + 1), tX);
420
421         tX = _mm_load_ss(base + align * offset[6]);
422         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[6] + 1));
423         tX = _mm_sub_ps(tX, tG);
424         _mm_store_ss(base + align * offset[6], tX);
425         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[6] + 1), tX);
426
427         tX = _mm_load_ss(base + align * offset[7]);
428         tX = _mm_loadh_pi(tX, reinterpret_cast<__m64*>(base + align * offset[7] + 1));
429         tX = _mm_sub_ps(tX, tH);
430         _mm_store_ss(base + align * offset[7], tX);
431         _mm_storeh_pi(reinterpret_cast<__m64*>(base + align * offset[7] + 1), tX);
432     }
433     else
434     {
435         // Extra elements means we can use full width-4 load/store operations
436         t1 = _mm256_unpacklo_ps(v0.simdInternal_, v2.simdInternal_);
437         t2 = _mm256_unpackhi_ps(v0.simdInternal_, v2.simdInternal_);
438         t3 = _mm256_unpacklo_ps(v1.simdInternal_, _mm256_setzero_ps());
439         t4 = _mm256_unpackhi_ps(v1.simdInternal_, _mm256_setzero_ps());
440         t5 = _mm256_unpacklo_ps(t1, t3); // x0 y0 z0  0 | x4 y4 z4 0
441         t6 = _mm256_unpackhi_ps(t1, t3); // x1 y1 z1  0 | x5 y5 z5 0
442         t7 = _mm256_unpacklo_ps(t2, t4); // x2 y2 z2  0 | x6 y6 z6 0
443         t8 = _mm256_unpackhi_ps(t2, t4); // x3 y3 z3  0 | x7 y7 z7 0
444
445         if (align % 4 == 0)
446         {
447             // We can use aligned load & store
448             _mm_store_ps(base + align * offset[0],
449                          _mm_sub_ps(_mm_load_ps(base + align * offset[0]), _mm256_castps256_ps128(t5)));
450             _mm_store_ps(base + align * offset[1],
451                          _mm_sub_ps(_mm_load_ps(base + align * offset[1]), _mm256_castps256_ps128(t6)));
452             _mm_store_ps(base + align * offset[2],
453                          _mm_sub_ps(_mm_load_ps(base + align * offset[2]), _mm256_castps256_ps128(t7)));
454             _mm_store_ps(base + align * offset[3],
455                          _mm_sub_ps(_mm_load_ps(base + align * offset[3]), _mm256_castps256_ps128(t8)));
456             _mm_store_ps(base + align * offset[4], _mm_sub_ps(_mm_load_ps(base + align * offset[4]),
457                                                               _mm256_extractf128_ps(t5, 0x1)));
458             _mm_store_ps(base + align * offset[5], _mm_sub_ps(_mm_load_ps(base + align * offset[5]),
459                                                               _mm256_extractf128_ps(t6, 0x1)));
460             _mm_store_ps(base + align * offset[6], _mm_sub_ps(_mm_load_ps(base + align * offset[6]),
461                                                               _mm256_extractf128_ps(t7, 0x1)));
462             _mm_store_ps(base + align * offset[7], _mm_sub_ps(_mm_load_ps(base + align * offset[7]),
463                                                               _mm256_extractf128_ps(t8, 0x1)));
464         }
465         else
466         {
467             // alignment >=5, but not a multiple of 4
468             _mm_storeu_ps(base + align * offset[0], _mm_sub_ps(_mm_loadu_ps(base + align * offset[0]),
469                                                                _mm256_castps256_ps128(t5)));
470             _mm_storeu_ps(base + align * offset[1], _mm_sub_ps(_mm_loadu_ps(base + align * offset[1]),
471                                                                _mm256_castps256_ps128(t6)));
472             _mm_storeu_ps(base + align * offset[2], _mm_sub_ps(_mm_loadu_ps(base + align * offset[2]),
473                                                                _mm256_castps256_ps128(t7)));
474             _mm_storeu_ps(base + align * offset[3], _mm_sub_ps(_mm_loadu_ps(base + align * offset[3]),
475                                                                _mm256_castps256_ps128(t8)));
476             _mm_storeu_ps(base + align * offset[4], _mm_sub_ps(_mm_loadu_ps(base + align * offset[4]),
477                                                                _mm256_extractf128_ps(t5, 0x1)));
478             _mm_storeu_ps(base + align * offset[5], _mm_sub_ps(_mm_loadu_ps(base + align * offset[5]),
479                                                                _mm256_extractf128_ps(t6, 0x1)));
480             _mm_storeu_ps(base + align * offset[6], _mm_sub_ps(_mm_loadu_ps(base + align * offset[6]),
481                                                                _mm256_extractf128_ps(t7, 0x1)));
482             _mm_storeu_ps(base + align * offset[7], _mm_sub_ps(_mm_loadu_ps(base + align * offset[7]),
483                                                                _mm256_extractf128_ps(t8, 0x1)));
484         }
485     }
486 }
487
488 static inline void gmx_simdcall expandScalarsToTriplets(SimdFloat  scalar,
489                                                         SimdFloat* triplets0,
490                                                         SimdFloat* triplets1,
491                                                         SimdFloat* triplets2)
492 {
493     __m256 t0 = _mm256_permute2f128_ps(scalar.simdInternal_, scalar.simdInternal_, 0x21);
494     __m256 t1 = _mm256_permute_ps(scalar.simdInternal_, _MM_SHUFFLE(1, 0, 0, 0));
495     __m256 t2 = _mm256_permute_ps(t0, _MM_SHUFFLE(2, 2, 1, 1));
496     __m256 t3 = _mm256_permute_ps(scalar.simdInternal_, _MM_SHUFFLE(3, 3, 3, 2));
497     triplets0->simdInternal_ = _mm256_blend_ps(t1, t2, 0xF0);
498     triplets1->simdInternal_ = _mm256_blend_ps(t3, t1, 0xF0);
499     triplets2->simdInternal_ = _mm256_blend_ps(t2, t3, 0xF0);
500 }
501
502 template<int align>
503 static inline void gmx_simdcall gatherLoadBySimdIntTranspose(const float* base,
504                                                              SimdFInt32   simdoffset,
505                                                              SimdFloat*   v0,
506                                                              SimdFloat*   v1,
507                                                              SimdFloat*   v2,
508                                                              SimdFloat*   v3)
509 {
510     alignas(GMX_SIMD_ALIGNMENT) std::int32_t offset[GMX_SIMD_FLOAT_WIDTH];
511     _mm256_store_si256(reinterpret_cast<__m256i*>(offset), simdoffset.simdInternal_);
512     gatherLoadTranspose<align>(base, offset, v0, v1, v2, v3);
513 }
514
515 template<int align>
516 static inline void gmx_simdcall
517                    gatherLoadBySimdIntTranspose(const float* base, SimdFInt32 simdoffset, SimdFloat* v0, SimdFloat* v1)
518 {
519     alignas(GMX_SIMD_ALIGNMENT) std::int32_t offset[GMX_SIMD_FLOAT_WIDTH];
520     _mm256_store_si256(reinterpret_cast<__m256i*>(offset), simdoffset.simdInternal_);
521     gatherLoadTranspose<align>(base, offset, v0, v1);
522 }
523
524
525 template<int align>
526 static inline void gmx_simdcall
527                    gatherLoadUBySimdIntTranspose(const float* base, SimdFInt32 simdoffset, SimdFloat* v0, SimdFloat* v1)
528 {
529     __m128 t1, t2, t3, t4, t5, t6, t7, t8;
530     __m256 tA, tB, tC, tD;
531
532     alignas(GMX_SIMD_ALIGNMENT) std::int32_t offset[GMX_SIMD_FLOAT_WIDTH];
533     _mm256_store_si256(reinterpret_cast<__m256i*>(offset), simdoffset.simdInternal_);
534
535     t1 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[0]));
536     t2 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[1]));
537     t3 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[2]));
538     t4 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[3]));
539     t5 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[4]));
540     t6 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[5]));
541     t7 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[6]));
542     t8 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base + align * offset[7]));
543
544     tA = _mm256_insertf128_ps(_mm256_castps128_ps256(t1), t5, 0x1);
545     tB = _mm256_insertf128_ps(_mm256_castps128_ps256(t2), t6, 0x1);
546     tC = _mm256_insertf128_ps(_mm256_castps128_ps256(t3), t7, 0x1);
547     tD = _mm256_insertf128_ps(_mm256_castps128_ps256(t4), t8, 0x1);
548
549     tA                = _mm256_unpacklo_ps(tA, tC);
550     tB                = _mm256_unpacklo_ps(tB, tD);
551     v0->simdInternal_ = _mm256_unpacklo_ps(tA, tB);
552     v1->simdInternal_ = _mm256_unpackhi_ps(tA, tB);
553 }
554
555 static inline float gmx_simdcall reduceIncr4ReturnSum(float* m, SimdFloat v0, SimdFloat v1, SimdFloat v2, SimdFloat v3)
556 {
557     __m128 t0, t2;
558
559     assert(std::size_t(m) % 16 == 0);
560
561     v0.simdInternal_ = _mm256_hadd_ps(v0.simdInternal_, v1.simdInternal_);
562     v2.simdInternal_ = _mm256_hadd_ps(v2.simdInternal_, v3.simdInternal_);
563     v0.simdInternal_ = _mm256_hadd_ps(v0.simdInternal_, v2.simdInternal_);
564     t0               = _mm_add_ps(_mm256_castps256_ps128(v0.simdInternal_), _mm256_extractf128_ps(v0.simdInternal_, 0x1));
565
566     t2 = _mm_add_ps(t0, _mm_load_ps(m));
567     _mm_store_ps(m, t2);
568
569     t0 = _mm_add_ps(t0, _mm_permute_ps(t0, _MM_SHUFFLE(1, 0, 3, 2)));
570     t0 = _mm_add_ss(t0, _mm_permute_ps(t0, _MM_SHUFFLE(0, 3, 2, 1)));
571     return *reinterpret_cast<float*>(&t0);
572 }
573
574
575 /*************************************
576  * Half-simd-width utility functions *
577  *************************************/
578 static inline SimdFloat gmx_simdcall loadDualHsimd(const float* m0, const float* m1)
579 {
580     assert(std::size_t(m0) % 16 == 0);
581     assert(std::size_t(m1) % 16 == 0);
582
583     return { _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(m0)), _mm_load_ps(m1), 0x1) };
584 }
585
586 static inline SimdFloat gmx_simdcall loadDuplicateHsimd(const float* m)
587 {
588     assert(std::size_t(m) % 16 == 0);
589
590     return { _mm256_broadcast_ps(reinterpret_cast<const __m128*>(m)) };
591 }
592
593 static inline SimdFloat gmx_simdcall loadU1DualHsimd(const float* m)
594 {
595     __m128 t0, t1;
596     t0 = _mm_broadcast_ss(m);
597     t1 = _mm_broadcast_ss(m + 1);
598     return { _mm256_insertf128_ps(_mm256_castps128_ps256(t0), t1, 0x1) };
599 }
600
601
602 static inline void gmx_simdcall storeDualHsimd(float* m0, float* m1, SimdFloat a)
603 {
604     assert(std::size_t(m0) % 16 == 0);
605     assert(std::size_t(m1) % 16 == 0);
606     _mm_store_ps(m0, _mm256_castps256_ps128(a.simdInternal_));
607     _mm_store_ps(m1, _mm256_extractf128_ps(a.simdInternal_, 0x1));
608 }
609
610 static inline void gmx_simdcall incrDualHsimd(float* m0, float* m1, SimdFloat a)
611 {
612     assert(std::size_t(m0) % 16 == 0);
613     assert(std::size_t(m1) % 16 == 0);
614     _mm_store_ps(m0, _mm_add_ps(_mm256_castps256_ps128(a.simdInternal_), _mm_load_ps(m0)));
615     _mm_store_ps(m1, _mm_add_ps(_mm256_extractf128_ps(a.simdInternal_, 0x1), _mm_load_ps(m1)));
616 }
617
618 static inline void gmx_simdcall decr3Hsimd(float* m, SimdFloat a0, SimdFloat a1, SimdFloat a2)
619 {
620     assert(std::size_t(m) % 16 == 0);
621     decrHsimd(m, a0);
622     decrHsimd(m + GMX_SIMD_FLOAT_WIDTH / 2, a1);
623     decrHsimd(m + GMX_SIMD_FLOAT_WIDTH, a2);
624 }
625
626
627 template<int align>
628 static inline void gmx_simdcall gatherLoadTransposeHsimd(const float*       base0,
629                                                          const float*       base1,
630                                                          const std::int32_t offset[],
631                                                          SimdFloat*         v0,
632                                                          SimdFloat*         v1)
633 {
634     __m128 t0, t1, t2, t3, t4, t5, t6, t7;
635     __m256 tA, tB, tC, tD;
636
637     assert(std::size_t(offset) % 16 == 0);
638     assert(std::size_t(base0) % 8 == 0);
639     assert(std::size_t(base1) % 8 == 0);
640     assert(align % 2 == 0);
641
642     t0 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base0 + align * offset[0]));
643     t1 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base0 + align * offset[1]));
644     t2 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base0 + align * offset[2]));
645     t3 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base0 + align * offset[3]));
646     t4 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base1 + align * offset[0]));
647     t5 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base1 + align * offset[1]));
648     t6 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base1 + align * offset[2]));
649     t7 = _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<const __m64*>(base1 + align * offset[3]));
650
651     tA = _mm256_insertf128_ps(_mm256_castps128_ps256(t0), t4, 0x1);
652     tB = _mm256_insertf128_ps(_mm256_castps128_ps256(t1), t5, 0x1);
653     tC = _mm256_insertf128_ps(_mm256_castps128_ps256(t2), t6, 0x1);
654     tD = _mm256_insertf128_ps(_mm256_castps128_ps256(t3), t7, 0x1);
655
656     tA                = _mm256_unpacklo_ps(tA, tC);
657     tB                = _mm256_unpacklo_ps(tB, tD);
658     v0->simdInternal_ = _mm256_unpacklo_ps(tA, tB);
659     v1->simdInternal_ = _mm256_unpackhi_ps(tA, tB);
660 }
661
662
663 static inline float gmx_simdcall reduceIncr4ReturnSumHsimd(float* m, SimdFloat v0, SimdFloat v1)
664 {
665     __m128 t0, t1;
666
667     v0.simdInternal_ = _mm256_hadd_ps(v0.simdInternal_, v1.simdInternal_);
668     t0               = _mm256_extractf128_ps(v0.simdInternal_, 0x1);
669     t0               = _mm_hadd_ps(_mm256_castps256_ps128(v0.simdInternal_), t0);
670     t0               = _mm_permute_ps(t0, _MM_SHUFFLE(3, 1, 2, 0));
671
672     assert(std::size_t(m) % 16 == 0);
673
674     t1 = _mm_add_ps(t0, _mm_load_ps(m));
675     _mm_store_ps(m, t1);
676
677     t0 = _mm_add_ps(t0, _mm_permute_ps(t0, _MM_SHUFFLE(1, 0, 3, 2)));
678     t0 = _mm_add_ss(t0, _mm_permute_ps(t0, _MM_SHUFFLE(0, 3, 2, 1)));
679     return *reinterpret_cast<float*>(&t0);
680 }
681
682 static inline SimdFloat gmx_simdcall loadU4NOffset(const float* m, int offset)
683 {
684     return { _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(m)), _mm_loadu_ps(m + offset), 0x1) };
685 }
686
687
688 } // namespace gmx
689
690 #endif // GMX_SIMD_IMPL_X86_AVX_256_UTIL_FLOAT_H