8508941b56efa8fca266f18bc402325094d7c8f5
[alexxy/gromacs.git] / src / gromacs / simd / impl_intel_mic / impl_intel_mic.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2014, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 #ifndef GMX_SIMD_IMPL_INTEL_MIC_H
37 #define GMX_SIMD_IMPL_INTEL_MIC_H
38
39 #include <math.h>
40 #include <immintrin.h>
41
42 #include "config.h"
43
44 /* Intel Xeon Phi, or
45  * the-artist-formerly-known-as-Knight's-corner, or
46  * the-artist-formerly-formerly-known-as-MIC, or
47  * the artist formerly-formerly-formerly-known-as-Larrabee
48  * 512-bit SIMD instruction wrappers.
49  */
50
51 /* Capability definitions for Xeon Phi SIMD */
52 #define GMX_SIMD_HAVE_FLOAT
53 #define GMX_SIMD_HAVE_DOUBLE
54 #define GMX_SIMD_HAVE_SIMD_HARDWARE
55 #define GMX_SIMD_HAVE_LOADU
56 #define GMX_SIMD_HAVE_STOREU
57 #define GMX_SIMD_HAVE_LOGICAL
58 #define GMX_SIMD_HAVE_FMA
59 #undef  GMX_SIMD_HAVE_FRACTION
60 #define GMX_SIMD_HAVE_FINT32
61 #define  GMX_SIMD_HAVE_FINT32_EXTRACT
62 #define GMX_SIMD_HAVE_FINT32_LOGICAL
63 #define GMX_SIMD_HAVE_FINT32_ARITHMETICS
64 #define GMX_SIMD_HAVE_DINT32
65 #define  GMX_SIMD_HAVE_DINT32_EXTRACT
66 #define GMX_SIMD_HAVE_DINT32_LOGICAL
67 #define GMX_SIMD_HAVE_DINT32_ARITHMETICS
68 #define GMX_SIMD4_HAVE_FLOAT
69 #define GMX_SIMD4_HAVE_DOUBLE
70
71 /* Implementation details */
72 #define GMX_SIMD_FLOAT_WIDTH        16
73 #define GMX_SIMD_DOUBLE_WIDTH        8
74 #define GMX_SIMD_FINT32_WIDTH       16
75 #define GMX_SIMD_DINT32_WIDTH        8
76 #define GMX_SIMD_RSQRT_BITS         23
77 #define GMX_SIMD_RCP_BITS           23
78
79 /****************************************************
80  *      SINGLE PRECISION SIMD IMPLEMENTATION        *
81  ****************************************************/
82 #define gmx_simd_float_t           __m512
83 #define gmx_simd_load_f            _mm512_load_ps
84 #define gmx_simd_load1_f(m)        _mm512_extload_ps(m, _MM_UPCONV_PS_NONE, _MM_BROADCAST_1X16, _MM_HINT_NONE)
85 #define gmx_simd_set1_f            _mm512_set1_ps
86 #define gmx_simd_store_f           _mm512_store_ps
87 #define gmx_simd_loadu_f           gmx_simd_loadu_f_mic
88 #define gmx_simd_storeu_f          gmx_simd_storeu_f_mic
89 #define gmx_simd_setzero_f         _mm512_setzero_ps
90 #define gmx_simd_add_f             _mm512_add_ps
91 #define gmx_simd_sub_f             _mm512_sub_ps
92 #define gmx_simd_mul_f             _mm512_mul_ps
93 #define gmx_simd_fmadd_f           _mm512_fmadd_ps
94 #define gmx_simd_fmsub_f           _mm512_fmsub_ps
95 #define gmx_simd_fnmadd_f          _mm512_fnmadd_ps
96 #define gmx_simd_fnmsub_f          _mm512_fnmsub_ps
97 #define gmx_simd_and_f(a, b)        _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(a), _mm512_castps_si512(b)))
98 #define gmx_simd_andnot_f(a, b)     _mm512_castsi512_ps(_mm512_andnot_epi32(_mm512_castps_si512(a), _mm512_castps_si512(b)))
99 #define gmx_simd_or_f(a, b)         _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(a), _mm512_castps_si512(b)))
100 #define gmx_simd_xor_f(a, b)        _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(a), _mm512_castps_si512(b)))
101 #define gmx_simd_rsqrt_f           _mm512_rsqrt23_ps
102 #define gmx_simd_rcp_f             _mm512_rcp23_ps
103 #define gmx_simd_fabs_f(x)         gmx_simd_andnot_f(_mm512_set1_ps(GMX_FLOAT_NEGZERO), x)
104 #define gmx_simd_fneg_f(x)         _mm512_addn_ps(x, _mm512_setzero_ps())
105 #define gmx_simd_max_f             _mm512_gmax_ps
106 #define gmx_simd_min_f             _mm512_gmin_ps
107 #define gmx_simd_round_f(x)        _mm512_round_ps(x, _MM_FROUND_TO_NEAREST_INT, _MM_EXPADJ_NONE)
108 #define gmx_simd_trunc_f(x)        _mm512_round_ps(x, _MM_FROUND_TO_ZERO, _MM_EXPADJ_NONE)
109 #define gmx_simd_fraction_f(x)     _mm512_sub_ps(x, gmx_simd_trunc_f(x))
110 #define gmx_simd_get_exponent_f(x) _mm512_getexp_ps(x)
111 #define gmx_simd_get_mantissa_f(x) _mm512_getmant_ps(x, _MM_MANT_NORM_1_2, _MM_MANT_SIGN_zero)
112 #define gmx_simd_set_exponent_f(x) gmx_simd_set_exponent_f_mic(x)
113 /* integer datatype corresponding to float: gmx_simd_fint32_t */
114 #define gmx_simd_fint32_t          __m512i
115 #define gmx_simd_load_fi           _mm512_load_epi32
116 #define gmx_simd_set1_fi           _mm512_set1_epi32
117 #define gmx_simd_store_fi          _mm512_store_epi32
118 #define gmx_simd_loadu_fi          gmx_simd_loadu_fi_mic
119 #define gmx_simd_storeu_fi         gmx_simd_storeu_fi_mic
120 #define gmx_simd_extract_fi        gmx_simd_extract_fi_mic
121 #define gmx_simd_setzero_fi        _mm512_setzero_epi32
122 #define gmx_simd_cvt_f2i(a)        _mm512_cvtfxpnt_round_adjustps_epi32(a, _MM_FROUND_TO_NEAREST_INT, _MM_EXPADJ_NONE)
123 #define gmx_simd_cvtt_f2i(a)       _mm512_cvtfxpnt_round_adjustps_epi32(a, _MM_FROUND_TO_ZERO, _MM_EXPADJ_NONE)
124 #define gmx_simd_cvt_i2f(a)        _mm512_cvtfxpnt_round_adjustepi32_ps(a, _MM_FROUND_TO_NEAREST_INT, _MM_EXPADJ_NONE)
125 /* Integer logical ops on gmx_simd_fint32_t */
126 #define gmx_simd_slli_fi           _mm512_slli_epi32
127 #define gmx_simd_srli_fi           _mm512_srli_epi32
128 #define gmx_simd_and_fi            _mm512_and_epi32
129 #define gmx_simd_andnot_fi         _mm512_andnot_epi32
130 #define gmx_simd_or_fi             _mm512_or_epi32
131 #define gmx_simd_xor_fi            _mm512_xor_epi32
132 /* Integer arithmetic ops on gmx_simd_fint32_t */
133 #define gmx_simd_add_fi            _mm512_add_epi32
134 #define gmx_simd_sub_fi            _mm512_sub_epi32
135 #define gmx_simd_mul_fi            _mm512_mullo_epi32
136 /* Boolean & comparison operations on gmx_simd_float_t */
137 #define gmx_simd_fbool_t           __mmask16
138 #define gmx_simd_cmpeq_f(a, b)     _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ)
139 #define gmx_simd_cmplt_f(a, b)     _mm512_cmp_ps_mask(a, b, _CMP_LT_OS)
140 #define gmx_simd_cmple_f(a, b)     _mm512_cmp_ps_mask(a, b, _CMP_LE_OS)
141 #define gmx_simd_and_fb            _mm512_kand
142 #define gmx_simd_andnot_fb(a, b)   _mm512_knot(_mm512_kor(a, b))
143 #define gmx_simd_or_fb             _mm512_kor
144 #define gmx_simd_anytrue_fb        _mm512_mask2int
145 #define gmx_simd_blendzero_f(a, sel)    _mm512_mask_mov_ps(_mm512_setzero_ps(), sel, a)
146 #define gmx_simd_blendnotzero_f(a, sel) _mm512_mask_mov_ps(_mm512_setzero_ps(), _mm512_knot(sel), a)
147 #define gmx_simd_blendv_f(a, b, sel)    _mm512_mask_blend_ps(sel, a, b)
148 #define gmx_simd_reduce_f(a)       _mm512_reduce_add_ps(a)
149 /* Boolean & comparison operations on gmx_simd_fint32_t */
150 #define gmx_simd_fibool_t          __mmask16
151 #define gmx_simd_cmpeq_fi(a, b)    _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ)
152 #define gmx_simd_cmplt_fi(a, b)    _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT)
153 #define gmx_simd_and_fib           _mm512_kand
154 #define gmx_simd_or_fib            _mm512_kor
155 #define gmx_simd_anytrue_fib       _mm512_mask2int
156 #define gmx_simd_blendzero_fi(a, sel)    _mm512_mask_mov_epi32(_mm512_setzero_epi32(), sel, a)
157 #define gmx_simd_blendnotzero_fi(a, sel) _mm512_mask_mov_epi32(_mm512_setzero_epi32(), _mm512_knot(sel), a)
158 #define gmx_simd_blendv_fi(a, b, sel)    _mm512_mask_blend_epi32(sel, a, b)
159 /* Conversions between different booleans */
160 #define gmx_simd_cvt_fb2fib(x)     (x)
161 #define gmx_simd_cvt_fib2fb(x)     (x)
162
163 /* MIC provides full single precision of some neat functions: */
164 /* 1/sqrt(x) and 1/x work fine in simd_math.h, and won't use extra iterations */
165
166 #define gmx_simd_exp2_f            gmx_simd_exp2_f_mic
167 #define gmx_simd_exp_f             gmx_simd_exp_f_mic
168 #define gmx_simd_log_f             gmx_simd_log_f_mic
169
170 /****************************************************
171  *      DOUBLE PRECISION SIMD IMPLEMENTATION        *
172  ****************************************************/
173 #define gmx_simd_double_t          __m512d
174 #define gmx_simd_load_d            _mm512_load_pd
175 #define gmx_simd_load1_d(m)        _mm512_extload_pd(m, _MM_UPCONV_PD_NONE, _MM_BROADCAST_1X8, _MM_HINT_NONE)
176 #define gmx_simd_set1_d            _mm512_set1_pd
177 #define gmx_simd_store_d           _mm512_store_pd
178 #define gmx_simd_loadu_d           gmx_simd_loadu_d_mic
179 #define gmx_simd_storeu_d          gmx_simd_storeu_d_mic
180 #define gmx_simd_setzero_d         _mm512_setzero_pd
181 #define gmx_simd_add_d             _mm512_add_pd
182 #define gmx_simd_sub_d             _mm512_sub_pd
183 #define gmx_simd_mul_d             _mm512_mul_pd
184 #define gmx_simd_fmadd_d           _mm512_fmadd_pd
185 #define gmx_simd_fmsub_d           _mm512_fmsub_pd
186 #define gmx_simd_fnmadd_d          _mm512_fnmadd_pd
187 #define gmx_simd_fnmsub_d          _mm512_fnmsub_pd
188 #define gmx_simd_and_d(a, b)       _mm512_castsi512_pd(_mm512_and_epi32(_mm512_castpd_si512(a), _mm512_castpd_si512(b)))
189 #define gmx_simd_andnot_d(a, b)    _mm512_castsi512_pd(_mm512_andnot_epi32(_mm512_castpd_si512(a), _mm512_castpd_si512(b)))
190 #define gmx_simd_or_d(a, b)        _mm512_castsi512_pd(_mm512_or_epi32(_mm512_castpd_si512(a), _mm512_castpd_si512(b)))
191 #define gmx_simd_xor_d(a, b)       _mm512_castsi512_pd(_mm512_xor_epi32(_mm512_castpd_si512(a), _mm512_castpd_si512(b)))
192 #define gmx_simd_rsqrt_d(x)        _mm512_cvtpslo_pd(_mm512_rsqrt23_ps(_mm512_cvtpd_pslo(x)))
193 #define gmx_simd_rcp_d(x)          _mm512_cvtpslo_pd(_mm512_rcp23_ps(_mm512_cvtpd_pslo(x)))
194 #define gmx_simd_fabs_d(x)         gmx_simd_andnot_d(_mm512_set1_pd(GMX_DOUBLE_NEGZERO), x)
195 #define gmx_simd_fneg_d(x)         _mm512_addn_pd(x, _mm512_setzero_pd())
196 #define gmx_simd_max_d             _mm512_gmax_pd
197 #define gmx_simd_min_d             _mm512_gmin_pd
198 #define gmx_simd_round_d(a)        _mm512_roundfxpnt_adjust_pd(a, _MM_FROUND_TO_NEAREST_INT, _MM_EXPADJ_NONE)
199 #define gmx_simd_trunc_d(a)        _mm512_roundfxpnt_adjust_pd(a, _MM_FROUND_TO_ZERO, _MM_EXPADJ_NONE)
200 #define gmx_simd_fraction_d(x)     _mm512_sub_pd(x, gmx_simd_trunc_d(x))
201 #define gmx_simd_get_exponent_d(x) _mm512_getexp_pd(x)
202 #define gmx_simd_get_mantissa_d(x) _mm512_getmant_pd(x, _MM_MANT_NORM_1_2, _MM_MANT_SIGN_zero)
203 #define gmx_simd_set_exponent_d(x) gmx_simd_set_exponent_d_mic(x)
204 /* integer datatype corresponding to float: gmx_simd_fint32_t
205    Doesn't use mask other than where required. No side effect expected for operating on the (unused) upper 8.
206  */
207 #define gmx_simd_dint32_t          __m512i
208 #define gmx_simd_load_di(m)        _mm512_mask_loadunpacklo_epi32(_mm512_undefined_epi32(), mask_loh, m)
209 #define gmx_simd_set1_di           _mm512_set1_epi32
210 #define gmx_simd_store_di(m, a)    _mm512_mask_packstorelo_epi32(m, mask_loh, a)
211 #define gmx_simd_loadu_di          gmx_simd_loadu_di_mic
212 #define gmx_simd_storeu_di         gmx_simd_storeu_di_mic
213 #define gmx_simd_extract_di        gmx_simd_extract_di_mic
214 #define gmx_simd_setzero_di        _mm512_setzero_epi32
215 #define gmx_simd_cvt_d2i(a)        _mm512_cvtfxpnt_roundpd_epi32lo(a, _MM_FROUND_TO_NEAREST_INT)
216 #define gmx_simd_cvtt_d2i(a)       _mm512_cvtfxpnt_roundpd_epi32lo(a, _MM_FROUND_TO_ZERO)
217 #define gmx_simd_cvt_i2d           _mm512_cvtepi32lo_pd
218 /* Integer logical ops on gmx_simd_fint32_t */
219 #define gmx_simd_slli_di           _mm512_slli_epi32
220 #define gmx_simd_srli_di           _mm512_srli_epi32
221 #define gmx_simd_and_di            _mm512_and_epi32
222 #define gmx_simd_andnot_di         _mm512_andnot_epi32
223 #define gmx_simd_or_di             _mm512_or_epi32
224 #define gmx_simd_xor_di            _mm512_xor_epi32
225 /* Integer arithmetic ops on gmx_simd_fint32_t */
226 #define gmx_simd_add_di            _mm512_add_epi32
227 #define gmx_simd_sub_di            _mm512_sub_epi32
228 #define gmx_simd_mul_di            _mm512_mullo_epi32
229 /* Boolean & comparison operations on gmx_simd_float_t */
230 #define gmx_simd_dbool_t           __mmask8
231 #define gmx_simd_cmpeq_d(a, b)     _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ)
232 #define gmx_simd_cmplt_d(a, b)     _mm512_cmp_pd_mask(a, b, _CMP_LT_OS)
233 #define gmx_simd_cmple_d(a, b)     _mm512_cmp_pd_mask(a, b, _CMP_LE_OS)
234 #define gmx_simd_and_db            _mm512_kand
235 #define gmx_simd_or_db             _mm512_kor
236 #define gmx_simd_anytrue_db(x)     _mm512_mask2int(x)
237 #define gmx_simd_blendzero_d(a, sel)    _mm512_mask_mov_pd(_mm512_setzero_pd(), sel, a)
238 #define gmx_simd_blendnotzero_d(a, sel) _mm512_mask_mov_pd(_mm512_setzero_pd(), _mm512_knot(sel), a)
239 #define gmx_simd_blendv_d(a, b, sel)    _mm512_mask_blend_pd(sel, a, b)
240 #define gmx_simd_reduce_d(a)       _mm512_reduce_add_pd(a)
241 /* Boolean & comparison operations on gmx_simd_fint32_t */
242 #define gmx_simd_dibool_t          __mmask16
243 #define gmx_simd_cmpeq_di(a, b)    _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ)
244 #define gmx_simd_cmplt_di(a, b)    _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT)
245 #define gmx_simd_and_dib           _mm512_kand
246 #define gmx_simd_or_dib            _mm512_kor
247 #define gmx_simd_anytrue_dib(x)    (_mm512_mask2int(x)&0xFF)
248 #define gmx_simd_blendzero_di(a, sel)    _mm512_mask_mov_epi32(_mm512_setzero_epi32(), sel, a)
249 #define gmx_simd_blendnotzero_di(a, sel) _mm512_mask_mov_epi32(_mm512_setzero_epi32(), _mm512_knot(sel), a)
250 #define gmx_simd_blendv_di(a, b, sel)    _mm512_mask_blend_epi32(sel, a, b)
251 /* Conversions between booleans. Double & dint stuff is stored in low bits */
252 #define gmx_simd_cvt_db2dib(x)     (x)
253 #define gmx_simd_cvt_dib2db(x)     (x)
254
255 /* Float/double conversion */
256 #define gmx_simd_cvt_f2dd          gmx_simd_cvt_f2dd_mic
257 #define gmx_simd_cvt_dd2f          gmx_simd_cvt_dd2f_mic
258
259 /****************************************************
260  *      SINGLE PRECISION SIMD4 IMPLEMENTATION       *
261  ****************************************************/
262 /* Load and store are guranteed to only access the 4 floats. All arithmetic operations
263    only operate on the 4 elements (to avoid floating excpetions). But other operations
264    are not gurateed to not modify the other 12 elements. E.g. setzero or blendzero
265    set the upper 12 to zero. */
266 #define gmx_simd4_float_t           __m512
267 #define gmx_simd4_mask              _mm512_int2mask(0xF)
268 #define gmx_simd4_load_f(m)         _mm512_mask_loadunpacklo_ps(_mm512_undefined_ps(), gmx_simd4_mask, m)
269 #define gmx_simd4_load1_f(m)        _mm512_mask_extload_ps(_mm512_undefined_ps(), gmx_simd4_mask, m, _MM_UPCONV_PS_NONE, _MM_BROADCAST_1X16, _MM_HINT_NONE)
270 #define gmx_simd4_set1_f            _mm512_set1_ps
271 #define gmx_simd4_store_f(m, a)     _mm512_mask_packstorelo_ps(m, gmx_simd4_mask, a)
272 #define gmx_simd4_loadu_f           gmx_simd4_loadu_f_mic
273 #define gmx_simd4_storeu_f          gmx_simd4_storeu_f_mic
274 #define gmx_simd4_setzero_f         _mm512_setzero_ps
275 #define gmx_simd4_add_f(a, b)       _mm512_mask_add_ps(_mm512_undefined_ps(), gmx_simd4_mask, a, b)
276 #define gmx_simd4_sub_f(a, b)       _mm512_mask_sub_ps(_mm512_undefined_ps(), gmx_simd4_mask, a, b)
277 #define gmx_simd4_mul_f(a, b)       _mm512_mask_mul_ps(_mm512_undefined_ps(), gmx_simd4_mask, a, b)
278 #define gmx_simd4_fmadd_f(a, b, c)  _mm512_mask_fmadd_ps(a, gmx_simd4_mask, b, c)
279 #define gmx_simd4_fmsub_f(a, b, c)  _mm512_mask_fmsub_ps(a, gmx_simd4_mask, b, c)
280 #define gmx_simd4_fnmadd_f(a, b, c) _mm512_mask_fnmadd_ps(a, gmx_simd4_mask, b, c)
281 #define gmx_simd4_fnmsub_f(a, b, c) _mm512_mask_fnmsub_ps(a, gmx_simd4_mask, b, c)
282 #define gmx_simd4_and_f(a, b)       _mm512_castsi512_ps(_mm512_mask_and_epi32(_mm512_undefined_epi32(), gmx_simd4_mask, _mm512_castps_si512(a), _mm512_castps_si512(b)))
283 #define gmx_simd4_andnot_f(a, b)    _mm512_castsi512_ps(_mm512_mask_andnot_epi32(_mm512_undefined_epi32(), gmx_simd4_mask, _mm512_castps_si512(a), _mm512_castps_si512(b)))
284 #define gmx_simd4_or_f(a, b)        _mm512_castsi512_ps(_mm512_mask_or_epi32(_mm512_undefined_epi32(), gmx_simd4_mask, _mm512_castps_si512(a), _mm512_castps_si512(b)))
285 #define gmx_simd4_xor_f(a, b)       _mm512_castsi512_ps(_mm512_mask_xor_epi32(_mm512_undefined_epi32(), gmx_simd4_mask, _mm512_castps_si512(a), _mm512_castps_si512(b)))
286 #define gmx_simd4_rsqrt_f(a)        _mm512_mask_rsqrt23_ps(_mm512_undefined_ps(), gmx_simd4_mask, a)
287 #define gmx_simd4_fabs_f(x)         gmx_simd4_andnot_f(_mm512_set1_ps(GMX_FLOAT_NEGZERO), x)
288 #define gmx_simd4_fneg_f(x)         _mm512_mask_addn_ps(_mm512_undefined_ps(), gmx_simd4_mask, x, _mm512_setzero_ps())
289 #define gmx_simd4_max_f(a, b)       _mm512_mask_gmax_ps(_mm512_undefined_ps(), gmx_simd4_mask, a, b)
290 #define gmx_simd4_min_f(a, b)       _mm512_mask_gmin_ps(_mm512_undefined_ps(), gmx_simd4_mask, a, b)
291 #define gmx_simd4_round_f(x)        _mm512_mask_round_ps(_mm512_undefined_ps(), gmx_simd4_mask, x, _MM_FROUND_TO_NEAREST_INT, _MM_EXPADJ_NONE)
292 #define gmx_simd4_trunc_f(x)        _mm512_mask_round_ps(_mm512_undefined_ps(), gmx_simd4_mask, x, _MM_FROUND_TO_ZERO, _MM_EXPADJ_NONE)
293 #define gmx_simd4_dotproduct3_f(a, b) _mm512_mask_reduce_add_ps(_mm512_int2mask(7), _mm512_mask_mul_ps(_mm512_undefined_ps(), _mm512_int2mask(7), a, b))
294 #define gmx_simd4_fbool_t           __mmask16
295 #define gmx_simd4_cmpeq_f(a, b)     _mm512_mask_cmp_ps_mask(gmx_simd4_mask, a, b, _CMP_EQ_OQ)
296 #define gmx_simd4_cmplt_f(a, b)     _mm512_mask_cmp_ps_mask(gmx_simd4_mask, a, b, _CMP_LT_OS)
297 #define gmx_simd4_cmple_f(a, b)     _mm512_mask_cmp_ps_mask(gmx_simd4_mask, a, b, _CMP_LE_OS)
298 #define gmx_simd4_and_fb            _mm512_kand
299 #define gmx_simd4_or_fb             _mm512_kor
300 #define gmx_simd4_anytrue_fb(x)     (_mm512_mask2int(x)&0xF)
301 #define gmx_simd4_blendzero_f(a, sel)    _mm512_mask_mov_ps(_mm512_setzero_ps(), sel, a)
302 #define gmx_simd4_blendnotzero_f(a, sel) _mm512_mask_mov_ps(_mm512_setzero_ps(), _mm512_knot(sel), a)
303 #define gmx_simd4_blendv_f(a, b, sel)    _mm512_mask_blend_ps(sel, a, b)
304 #define gmx_simd4_reduce_f(x)       _mm512_mask_reduce_add_ps(_mm512_int2mask(0xF), x)
305
306 /****************************************************
307  *      DOUBLE PRECISION SIMD4 IMPLEMENTATION       *
308  ****************************************************/
309 #define gmx_simd4_double_t          __m512d
310 #define gmx_simd4_mask              _mm512_int2mask(0xF)
311 #define gmx_simd4_load_d(m)         _mm512_mask_loadunpacklo_pd(_mm512_undefined_pd(), gmx_simd4_mask, m)
312 #define gmx_simd4_load1_d(m)        _mm512_mask_extload_pd(_mm512_undefined_pd(), gmx_simd4_mask, m, _MM_UPCONV_PD_NONE, _MM_BROADCAST_1X8, _MM_HINT_NONE)
313 #define gmx_simd4_set1_d            _mm512_set1_pd
314 #define gmx_simd4_store_d(m, a)     _mm512_mask_packstorelo_pd(m, gmx_simd4_mask, a)
315 #define gmx_simd4_loadu_d           gmx_simd4_loadu_d_mic
316 #define gmx_simd4_storeu_d          gmx_simd4_storeu_d_mic
317 #define gmx_simd4_setzero_d         _mm512_setzero_pd
318 #define gmx_simd4_add_d(a, b)       _mm512_mask_add_pd(_mm512_undefined_pd(), gmx_simd4_mask, a, b)
319 #define gmx_simd4_sub_d(a, b)       _mm512_mask_sub_pd(_mm512_undefined_pd(), gmx_simd4_mask, a, b)
320 #define gmx_simd4_mul_d(a, b)       _mm512_mask_mul_pd(_mm512_undefined_pd(), gmx_simd4_mask, a, b)
321 #define gmx_simd4_fmadd_d(a, b, c)  _mm512_mask_fmadd_pd(a, gmx_simd4_mask, b, c)
322 #define gmx_simd4_fmsub_d(a, b, c)  _mm512_mask_fmsub_pd(a, gmx_simd4_mask, b, c)
323 #define gmx_simd4_fnmadd_d(a, b, c) _mm512_mask_fnmadd_pd(a, gmx_simd4_mask, b, c)
324 #define gmx_simd4_fnmsub_d(a, b, c) _mm512_mask_fnmsub_pd(a, gmx_simd4_mask, b, c)
325 #define gmx_simd4_and_d(a, b)       _mm512_castsi512_pd(_mm512_mask_and_epi32(_mm512_undefined_epi32(), mask_loh, _mm512_castpd_si512(a), _mm512_castpd_si512(b)))
326 #define gmx_simd4_andnot_d(a, b)    _mm512_castsi512_pd(_mm512_mask_andnot_epi32(_mm512_undefined_epi32(), mask_loh, _mm512_castpd_si512(a), _mm512_castpd_si512(b)))
327 #define gmx_simd4_or_d(a, b)        _mm512_castsi512_pd(_mm512_mask_or_epi32(_mm512_undefined_epi32(), mask_loh, _mm512_castpd_si512(a), _mm512_castpd_si512(b)))
328 #define gmx_simd4_xor_d(a, b)       _mm512_castsi512_pd(_mm512_mask_xor_epi32(_mm512_undefined_epi32(), mask_loh, _mm512_castpd_si512(a), _mm512_castpd_si512(b)))
329 #define gmx_simd4_rsqrt_d(a)        _mm512_mask_cvtpslo_pd(_mm512_undefined_pd(), gmx_simd4_mask, _mm512_mask_rsqrt23_ps(_mm512_undefined_ps(), gmx_simd4_mask, _mm512_mask_cvtpd_pslo(_mm512_undefined_ps(), gmx_simd4_mask, x)))
330 #define gmx_simd4_fabs_d(x)         gmx_simd4_andnot_d(_mm512_set1_pd(GMX_DOUBLE_NEGZERO), x)
331 #define gmx_simd4_fneg_d(x)         _mm512_mask_addn_pd(_mm512_undefined_pd(), gmx_simd4_mask, x, _mm512_setzero_pd())
332 #define gmx_simd4_max_d(a, b)       _mm512_mask_gmax_pd(_mm512_undefined_pd(), gmx_simd4_mask, a, b)
333 #define gmx_simd4_min_d(a, b)       _mm512_mask_gmin_pd(_mm512_undefined_pd(), gmx_simd4_mask, a, b)
334 #define gmx_simd4_round_d(a)        _mm512_mask_roundfxpnt_adjust_pd(_mm512_undefined_pd(), gmx_simd4_mask, a, _MM_FROUND_TO_NEAREST_INT, _MM_EXPADJ_NONE)
335 #define gmx_simd4_trunc_d(a)        _mm512_mask_roundfxpnt_adjust_pd(_mm512_undefined_pd(), gmx_simd4_mask, a, _MM_FROUND_TO_ZERO, _MM_EXPADJ_NONE)
336 #define gmx_simd4_dotproduct3_d(a, b) _mm512_mask_reduce_add_pd(_mm512_int2mask(7), _mm512_mask_mul_pd(_mm512_undefined_pd(), _mm512_int2mask(7), a, b))
337 #define gmx_simd4_dbool_t           __mmask16
338 #define gmx_simd4_cmpeq_d(a, b)     _mm512_mask_cmp_pd_mask(gmx_simd4_mask, a, b, _CMP_EQ_OQ)
339 #define gmx_simd4_cmplt_d(a, b)     _mm512_mask_cmp_pd_mask(gmx_simd4_mask, a, b, _CMP_LT_OS)
340 #define gmx_simd4_cmple_d(a, b)     _mm512_mask_cmp_pd_mask(gmx_simd4_mask, a, b, _CMP_LE_OS)
341 #define gmx_simd4_and_db            _mm512_kand
342 #define gmx_simd4_or_db             _mm512_kor
343 #define gmx_simd4_anytrue_db(x)     (_mm512_mask2int(x)&0xF)
344 #define gmx_simd4_blendzero_d(a, sel)    _mm512_mask_mov_pd(_mm512_setzero_pd(), sel, a)
345 #define gmx_simd4_blendnotzero_d(a, sel) _mm512_mask_mov_pd(_mm512_setzero_pd(), _mm512_knot(sel), a)
346 #define gmx_simd4_blendv_d(a, b, sel)    _mm512_mask_blend_pd(sel, a, b)
347 #define gmx_simd4_reduce_d(x)       _mm512_mask_reduce_add_pd(_mm512_int2mask(0xF), x)
348
349 #define PERM_LOW2HIGH _MM_PERM_BABA
350 #define PERM_HIGH2LOW _MM_PERM_DCDC
351
352 #define mask_loh _mm512_int2mask(0x00FF) /* would be better a constant - but can't initialize with a function call. */
353 #define mask_hih _mm512_int2mask(0xFF00)
354
355 /* load store float */
356 static gmx_inline __m512 gmx_simdcall
357 gmx_simd_loadu_f_mic(const float * m)
358 {
359     return _mm512_loadunpackhi_ps(_mm512_loadunpacklo_ps(_mm512_undefined_ps(), m), m+16);
360 }
361
362 static gmx_inline void gmx_simdcall
363 gmx_simd_storeu_f_mic(float * m, __m512 s)
364 {
365     _mm512_packstorelo_ps(m, s);
366     _mm512_packstorehi_ps(m+16, s);
367 }
368
369 /* load store fint32 */
370 static gmx_inline __m512i gmx_simdcall
371 gmx_simd_loadu_fi_mic(const gmx_int32_t * m)
372 {
373     return _mm512_loadunpackhi_epi32(_mm512_loadunpacklo_epi32(_mm512_undefined_epi32(), m), m+16);
374 }
375
376 static gmx_inline void gmx_simdcall
377 gmx_simd_storeu_fi_mic(gmx_int32_t * m, __m512i s)
378 {
379     _mm512_packstorelo_epi32(m, s);
380     _mm512_packstorehi_epi32(m+16, s);
381 }
382
383 /* load store double */
384 static gmx_inline __m512d gmx_simdcall
385 gmx_simd_loadu_d_mic(const double * m)
386 {
387     return _mm512_loadunpackhi_pd(_mm512_loadunpacklo_pd(_mm512_undefined_pd(), m), m+8);
388 }
389
390 static gmx_inline void gmx_simdcall
391 gmx_simd_storeu_d_mic(double * m, __m512d s)
392 {
393     _mm512_packstorelo_pd(m, s);
394     _mm512_packstorehi_pd(m+8, s);
395 }
396
397 /* load store dint32 */
398 static gmx_inline __m512i gmx_simdcall
399 gmx_simd_loadu_di_mic(const gmx_int32_t * m)
400 {
401     return _mm512_mask_loadunpackhi_epi32(_mm512_mask_loadunpacklo_epi32(_mm512_undefined_epi32(), mask_loh, m), mask_loh, m+16);
402 }
403
404 static gmx_inline void gmx_simdcall
405 gmx_simd_storeu_di_mic(gmx_int32_t * m, __m512i s)
406 {
407     _mm512_mask_packstorelo_epi32(m, mask_loh, s);
408     _mm512_mask_packstorehi_epi32(m+16, mask_loh, s);
409 }
410
411 /* load store simd4 */
412 static gmx_inline __m512 gmx_simdcall
413 gmx_simd4_loadu_f_mic(const float * m)
414 {
415     return _mm512_mask_loadunpackhi_ps(_mm512_mask_loadunpacklo_ps(_mm512_undefined_ps(), gmx_simd4_mask, m), gmx_simd4_mask, m+16);
416 }
417
418 static gmx_inline void gmx_simdcall
419 gmx_simd4_storeu_f_mic(float * m, __m512 s)
420 {
421     _mm512_mask_packstorelo_ps(m, gmx_simd4_mask, s);
422     _mm512_mask_packstorehi_ps(m+16, gmx_simd4_mask, s);
423 }
424
425 static gmx_inline __m512d gmx_simdcall
426 gmx_simd4_loadu_d_mic(const double * m)
427 {
428     return _mm512_mask_loadunpackhi_pd(_mm512_mask_loadunpacklo_pd(_mm512_undefined_pd(), gmx_simd4_mask, m), gmx_simd4_mask, m+8);
429 }
430
431 static gmx_inline void gmx_simdcall
432 gmx_simd4_storeu_d_mic(double * m, __m512d s)
433 {
434     _mm512_mask_packstorelo_pd(m, gmx_simd4_mask, s);
435     _mm512_mask_packstorehi_pd(m+8, gmx_simd4_mask, s);
436 }
437
438 /* extract */
439 static gmx_inline gmx_int32_t gmx_simdcall
440 gmx_simd_extract_fi_mic(gmx_simd_fint32_t a, int index)
441 {
442     int r;
443     _mm512_mask_packstorelo_epi32(&r, _mm512_mask2int(1<<index), a);
444     return r;
445 }
446
447 static gmx_inline gmx_int32_t gmx_simdcall
448 gmx_simd_extract_di_mic(gmx_simd_dint32_t a, int index)
449 {
450     int r;
451     _mm512_mask_packstorelo_epi32(&r, _mm512_mask2int(1<<index), a);
452     return r;
453 }
454
455 /* This is likely faster than the built in scale operation (lat 8, t-put 3)
456  * since we only work on the integer part and use shifts. TODO: check. given that scale also only does integer
457  */
458 static gmx_inline __m512 gmx_simdcall
459 gmx_simd_set_exponent_f_mic(__m512 a)
460 {
461     __m512i       iexp         = gmx_simd_cvt_f2i(a);
462
463     const __m512i expbias      = _mm512_set1_epi32(127);
464     iexp = _mm512_slli_epi32(_mm512_add_epi32(iexp, expbias), 23);
465     return _mm512_castsi512_ps(iexp);
466
467     /* scale alternative:
468        return _mm512_scale_ps(_mm512_set1_ps(1), iexp);
469      */
470 }
471
472 static gmx_inline __m512d gmx_simdcall
473 gmx_simd_set_exponent_d_mic(__m512d a)
474 {
475     const __m512i expbias      = _mm512_set1_epi32(1023);
476     __m512i       iexp         = _mm512_cvtfxpnt_roundpd_epi32lo(a, _MM_FROUND_TO_NEAREST_INT);
477     iexp = _mm512_permutevar_epi32(_mm512_set_epi32(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0), iexp);
478     iexp = _mm512_mask_slli_epi32(_mm512_setzero_epi32(), _mm512_int2mask(0xAAAA), _mm512_add_epi32(iexp, expbias), 20);
479     return _mm512_castsi512_pd(iexp);
480 }
481
482 static gmx_inline void gmx_simdcall
483 gmx_simd_cvt_f2dd_mic(__m512 f, __m512d * d0, __m512d * d1)
484 {
485     __m512i i1 = _mm512_permute4f128_epi32(_mm512_castps_si512(f), _MM_PERM_CDCD);
486
487     *d0 = _mm512_cvtpslo_pd(f);
488     *d1 = _mm512_cvtpslo_pd(_mm512_castsi512_ps(i1));
489 }
490
491 static gmx_inline __m512 gmx_simdcall
492 gmx_simd_cvt_dd2f_mic(__m512d d0, __m512d d1)
493 {
494     __m512 f0 = _mm512_cvtpd_pslo(d0);
495     __m512 f1 = _mm512_cvtpd_pslo(d1);
496     return _mm512_mask_permute4f128_ps(f0, mask_hih, f1, PERM_LOW2HIGH);
497 }
498
499 static gmx_inline __m512 gmx_simdcall
500 gmx_simd_exp2_f_mic(__m512 x)
501 {
502     return _mm512_exp223_ps(_mm512_cvtfxpnt_round_adjustps_epi32(x, _MM_ROUND_MODE_NEAREST, _MM_EXPADJ_24));
503 }
504
505 static gmx_inline __m512 gmx_simdcall
506 gmx_simd_exp_f_mic(__m512 x)
507 {
508     const gmx_simd_float_t  argscale    = gmx_simd_set1_f(1.44269504088896341f);
509     const gmx_simd_float_t  invargscale = gmx_simd_set1_f(-0.69314718055994528623f);
510     __m512                  xscaled     = _mm512_mul_ps(x, argscale);
511     __m512                  r           = gmx_simd_exp2_f_mic(xscaled);
512
513     /* gmx_simd_exp2_f_mic() provides 23 bits of accuracy, but we ruin some of that
514      * with the argument scaling due to single-precision rounding, where the
515      * rounding error is amplified exponentially. To correct this, we find the
516      * difference between the scaled argument and the true one (extended precision
517      * arithmetics does not appear to be necessary to fulfill our accuracy requirements)
518      * and then multiply by the exponent of this correction since exp(a+b)=exp(a)*exp(b).
519      * Note that this only adds two instructions (and maybe some constant loads).
520      */
521     x         = gmx_simd_fmadd_f(invargscale, xscaled, x);
522     /* x will now be a _very_ small number, so approximate exp(x)=1+x.
523      * We should thus apply the correction as r'=r*(1+x)=r+r*x
524      */
525     r         = gmx_simd_fmadd_f(r, x, r);
526     return r;
527 }
528
529 static gmx_inline __m512 gmx_simdcall
530 gmx_simd_log_f_mic(__m512 x)
531 {
532     return _mm512_mul_ps(_mm512_set1_ps(0.693147180559945286226764), _mm512_log2ae23_ps(x));
533 }
534
535 /* Function to check whether SIMD operations have resulted in overflow */
536 static int
537 gmx_simd_check_and_reset_overflow(void)
538 {
539     int                MXCSR;
540     int                sse_overflow;
541     /* The overflow flag is bit 3 in the register */
542     const unsigned int flag = 0x8;
543
544     MXCSR = _mm_getcsr();
545     if (MXCSR & flag)
546     {
547         sse_overflow = 1;
548         /* Set the overflow flag to zero */
549         MXCSR = MXCSR & ~flag;
550         _mm_setcsr(MXCSR);
551     }
552     else
553     {
554         sse_overflow = 0;
555     }
556     return sse_overflow;
557 }
558
559 #endif /* GMX_SIMD_IMPL_INTEL_MIC_H */