Add 32-bit ARM Neon SIMD support
[alexxy/gromacs.git] / cmake / gmxTestSimd.cmake
index 9d910073dc099e5f4757deb67c3e448e0b5c5f97..47213b67f08c01941ddad6d44ce1aa4815d0f780 100644 (file)
@@ -66,22 +66,20 @@ elseif(${GMX_SIMD} STREQUAL "SSE2")
 
     gmx_find_cflag_for_source(CFLAGS_SSE2 "C compiler SSE2 flag"
                               "#include<xmmintrin.h>
-                              int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_rsqrt_ps(x);return 0;}"
+                              int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_rsqrt_ps(x);return _mm_movemask_ps(x);}"
                               SIMD_C_FLAGS
-                              "-msse2" "/arch:SSE2")
+                              "-msse2" "/arch:SSE2" "-hgnu")
     gmx_find_cxxflag_for_source(CXXFLAGS_SSE2 "C++ compiler SSE2 flag"
                                 "#include<xmmintrin.h>
-                                int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_rsqrt_ps(x);return 0;}"
+                                int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_rsqrt_ps(x);return _mm_movemask_ps(x);}"
                                 SIMD_CXX_FLAGS
-                                "-msse2" "/arch:SSE2")
+                                "-msse2" "/arch:SSE2" "-hgnu")
 
     if(NOT CFLAGS_SSE2 OR NOT CXXFLAGS_SSE2)
         message(FATAL_ERROR "Cannot find SSE2 compiler flag. Use a newer compiler, or disable SIMD (slower).")
     endif()
 
     set(GMX_SIMD_X86_SSE2 1)
-    set(GMX_SIMD_X86_SSE2_OR_HIGHER 1)
-
     set(SIMD_STATUS_MESSAGE "Enabling SSE2 SIMD instructions")
 
 elseif(${GMX_SIMD} STREQUAL "SSE4.1")
@@ -89,14 +87,14 @@ elseif(${GMX_SIMD} STREQUAL "SSE4.1")
     # Note: MSVC enables SSE4.1 with the SSE2 flag, so we include that in testing.
     gmx_find_cflag_for_source(CFLAGS_SSE4_1 "C compiler SSE4.1 flag"
                               "#include<smmintrin.h>
-                              int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_dp_ps(x,x,0x77);return 0;}"
+                              int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_dp_ps(x,x,0x77);return _mm_movemask_ps(x);}"
                               SIMD_C_FLAGS
-                              "-msse4.1" "/arch:SSE4.1" "/arch:SSE2")
+                              "-msse4.1" "/arch:SSE4.1" "/arch:SSE2" "-hgnu")
     gmx_find_cxxflag_for_source(CXXFLAGS_SSE4_1 "C++ compiler SSE4.1 flag"
                                 "#include<smmintrin.h>
-                                int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_dp_ps(x,x,0x77);return 0;}"
+                                int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_dp_ps(x,x,0x77);return _mm_movemask_ps(x);}"
                                 SIMD_CXX_FLAGS
-                                "-msse4.1" "/arch:SSE4.1" "/arch:SSE2")
+                                "-msse4.1" "/arch:SSE4.1" "/arch:SSE2" "-hgnu")
 
     if(NOT CFLAGS_SSE4_1 OR NOT CXXFLAGS_SSE4_1)
         message(FATAL_ERROR "Cannot find SSE4.1 compiler flag. "
@@ -108,8 +106,6 @@ elseif(${GMX_SIMD} STREQUAL "SSE4.1")
     endif()
 
     set(GMX_SIMD_X86_SSE4_1 1)
-    set(GMX_SIMD_X86_SSE4_1_OR_HIGHER 1)
-    set(GMX_SIMD_X86_SSE2_OR_HIGHER   1)
     set(SIMD_STATUS_MESSAGE "Enabling SSE4.1 SIMD instructions")
 
 elseif(${GMX_SIMD} STREQUAL "AVX_128_FMA")
@@ -126,12 +122,12 @@ elseif(${GMX_SIMD} STREQUAL "AVX_128_FMA")
                               "#include<immintrin.h>
                               int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_permute_ps(x,1);return 0;}"
                               SIMD_C_FLAGS
-                              "-mavx" "/arch:AVX")
+                              "-mavx" "/arch:AVX" "-hgnu")
     gmx_find_cxxflag_for_source(CXXFLAGS_AVX_128 "C++ compiler AVX (128 bit) flag"
                                 "#include<immintrin.h>
                                 int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_permute_ps(x,1);return 0;}"
                                 SIMD_CXX_FLAGS
-                                "-mavx" "/arch:AVX")
+                                "-mavx" "/arch:AVX" "-hgnu")
 
     ### STAGE 2: Find the fused-multiply add flag.
     # GCC requires x86intrin.h for FMA support. MSVC 2010 requires intrin.h for FMA support.
@@ -148,16 +144,16 @@ elseif(${GMX_SIMD} STREQUAL "AVX_128_FMA")
 "#include<immintrin.h>
 ${INCLUDE_X86INTRIN_H}
 ${INCLUDE_INTRIN_H}
-int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_macc_ps(x,x,x);return 0;}"
+int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_macc_ps(x,x,x);return _mm_movemask_ps(x);}"
                               SIMD_C_FLAGS
-                              "-mfma4")
+                              "-mfma4" "-hgnu")
     gmx_find_cxxflag_for_source(CXXFLAGS_AVX_128_FMA "C++ compiler AVX (128 bit) FMA4 flag"
 "#include<immintrin.h>
 ${INCLUDE_X86INTRIN_H}
 ${INCLUDE_INTRIN_H}
-int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_macc_ps(x,x,x);return 0;}"
+int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_macc_ps(x,x,x);return _mm_movemask_ps(x);}"
                                 SIMD_CXX_FLAGS
-                                "-mfma4")
+                                "-mfma4" "-hgnu")
 
     # We only need to check the last (FMA) test; that will always fail if the basic AVX128 test failed
     if(NOT CFLAGS_AVX_128_FMA OR NOT CXXFLAGS_AVX_128_FMA)
@@ -169,14 +165,14 @@ int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_macc_ps(x,x,x);return 0;}"
 "#include<immintrin.h>
 ${INCLUDE_X86INTRIN_H}
 ${INCLUDE_INTRIN_H}
-int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_frcz_ps(x);return 0;}"
+int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_frcz_ps(x);return _mm_movemask_ps(x);}"
                               SIMD_C_FLAGS
                               "-mxop")
     gmx_find_cxxflag_for_source(CXXFLAGS_AVX_128_XOP "C++ compiler AVX (128 bit) XOP flag"
 "#include<immintrin.h>
 ${INCLUDE_X86INTRIN_H}
 ${INCLUDE_INTRIN_H}
-int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_frcz_ps(x);return 0;}"
+int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_frcz_ps(x);return _mm_movemask_ps(x);}"
                                 SIMD_CXX_FLAGS
                                 "-mxop")
 
@@ -203,10 +199,6 @@ int main(){__m128 x=_mm_set1_ps(0.5);x=_mm_frcz_ps(x);return 0;}"
     gmx_test_avx_gcc_maskload_bug(GMX_SIMD_X86_AVX_GCC_MASKLOAD_BUG "${SIMD_C_FLAGS}")
 
     set(GMX_SIMD_X86_AVX_128_FMA 1)
-    set(GMX_SIMD_X86_AVX_128_FMA_OR_HIGHER 1)
-    set(GMX_SIMD_X86_SSE4_1_OR_HIGHER      1)
-    set(GMX_SIMD_X86_SSE2_OR_HIGHER        1)
-
     set(SIMD_STATUS_MESSAGE "Enabling 128-bit AVX SIMD Gromacs SIMD (with fused-multiply add)")
 
 elseif(${GMX_SIMD} STREQUAL "AVX_256")
@@ -215,14 +207,14 @@ elseif(${GMX_SIMD} STREQUAL "AVX_256")
 
     gmx_find_cflag_for_source(CFLAGS_AVX "C compiler AVX (256 bit) flag"
                               "#include<immintrin.h>
-                              int main(){__m256 x=_mm256_set1_ps(0.5);x=_mm256_add_ps(x,x);return 0;}"
+                              int main(){__m256 x=_mm256_set1_ps(0.5);x=_mm256_add_ps(x,x);return _mm256_movemask_ps(x);}"
                               SIMD_C_FLAGS
-                              "-mavx" "/arch:AVX")
+                              "-mavx" "/arch:AVX" "-hgnu")
     gmx_find_cxxflag_for_source(CXXFLAGS_AVX "C++ compiler AVX (256 bit) flag"
                                 "#include<immintrin.h>
-                                int main(){__m256 x=_mm256_set1_ps(0.5);x=_mm256_add_ps(x,x);return 0;}"
+                                int main(){__m256 x=_mm256_set1_ps(0.5);x=_mm256_add_ps(x,x);return _mm256_movemask_ps(x);}"
                                 SIMD_CXX_FLAGS
-                                "-mavx" "/arch:AVX")
+                                "-mavx" "/arch:AVX" "-hgnu")
 
     if(NOT CFLAGS_AVX OR NOT CXXFLAGS_AVX)
         message(FATAL_ERROR "Cannot find AVX compiler flag. Use a newer compiler, or choose SSE4.1 SIMD (slower).")
@@ -231,29 +223,22 @@ elseif(${GMX_SIMD} STREQUAL "AVX_256")
     gmx_test_avx_gcc_maskload_bug(GMX_SIMD_X86_AVX_GCC_MASKLOAD_BUG "${SIMD_C_FLAGS}")
 
     set(GMX_SIMD_X86_AVX_256 1)
-    set(GMX_SIMD_X86_AVX_256_OR_HIGHER  1)
-    set(GMX_SIMD_X86_SSE4_1_OR_HIGHER   1)
-    set(GMX_SIMD_X86_SSE2_OR_HIGHER     1)
-
     set(SIMD_STATUS_MESSAGE "Enabling 256-bit AVX SIMD instructions")
 
 elseif(${GMX_SIMD} STREQUAL "AVX2_256")
 
-    # Comment out this line for AVX2 development
-    message(FATAL_ERROR "AVX2_256 is disabled until the implementation has been commited.")
-
     gmx_use_clang_as_with_gnu_compilers_on_osx()
 
     gmx_find_cflag_for_source(CFLAGS_AVX2 "C compiler AVX2 flag"
                               "#include<immintrin.h>
-                              int main(){__m256 x=_mm256_set1_ps(0.5);x=_mm256_fmadd_ps(x,x,x);return 0;}"
+                              int main(){__m256i x=_mm256_set1_epi32(5);x=_mm256_add_epi32(x,x);return _mm256_movemask_epi8(x);}"
                               SIMD_C_FLAGS
-                              "-march=core-avx2" "-mavx2" "/arch:AVX") # no AVX2-specific flag for MSVC yet
+                              "-march=core-avx2" "-mavx2" "/arch:AVX" "-hgnu") # no AVX2-specific flag for MSVC yet
     gmx_find_cxxflag_for_source(CXXFLAGS_AVX2 "C++ compiler AVX2 flag"
                                 "#include<immintrin.h>
-                                int main(){__m256 x=_mm256_set1_ps(0.5);x=_mm256_fmadd_ps(x,x,x);return 0;}"
+                                int main(){__m256i x=_mm256_set1_epi32(5);x=_mm256_add_epi32(x,x);return _mm256_movemask_epi8(x);}"
                                 SIMD_CXX_FLAGS
-                                "-march=core-avx2" "-mavx2" "/arch:AVX") # no AVX2-specific flag for MSVC yet
+                                "-march=core-avx2" "-mavx2" "/arch:AVX" "-hgnu") # no AVX2-specific flag for MSVC yet
 
     if(NOT CFLAGS_AVX2 OR NOT CXXFLAGS_AVX2)
         message(FATAL_ERROR "Cannot find AVX2 compiler flag. Use a newer compiler, or choose AVX SIMD (slower).")
@@ -262,13 +247,28 @@ elseif(${GMX_SIMD} STREQUAL "AVX2_256")
     # No need to test for Maskload bug - it was fixed before gcc added AVX2 support
 
     set(GMX_SIMD_X86_AVX2_256 1)
-    set(GMX_SIMD_X86_AVX2_256_OR_HIGHER 1)
-    set(GMX_SIMD_X86_AVX_256_OR_HIGHER  1)
-    set(GMX_SIMD_X86_SSE4_1_OR_HIGHER   1)
-    set(GMX_SIMD_X86_SSE2_OR_HIGHER     1)
-
     set(SIMD_STATUS_MESSAGE "Enabling 256-bit AVX2 SIMD instructions")
 
+elseif(${GMX_SIMD} STREQUAL "ARM_NEON")
+
+    gmx_find_cflag_for_source(CFLAGS_ARM_NEON "C compiler 32-bit ARM NEON flag"
+                              "#include<arm_neon.h>
+                              int main(){float32x4_t x=vdupq_n_f32(0.5);x=vmlaq_f32(x,x,x);return vgetq_lane_f32(x,0)>0;}"
+                              SIMD_C_FLAGS
+                              "-mfpu=neon" "")
+    gmx_find_cxxflag_for_source(CXXFLAGS_ARM_NEON "C++ compiler 32-bit ARM NEON flag"
+                                "#include<arm_neon.h>
+                                int main(){float32x4_t x=vdupq_n_f32(0.5);x=vmlaq_f32(x,x,x);return vgetq_lane_f32(x,0)>0;}"
+                                SIMD_CXX_FLAGS
+                                "-mfpu=neon" "")
+
+    if(NOT CFLAGS_ARM_NEON OR NOT CXXFLAGS_ARM_NEON)
+        message(FATAL_ERROR "Cannot find ARM 32-bit NEON compiler flag. Use a newer compiler, or disable NEON SIMD.")
+    endif()
+
+    set(GMX_SIMD_ARM_NEON 1)
+    set(SIMD_STATUS_MESSAGE "Enabling 32-bit ARM NEON SIMD instructions")
+
 elseif(${GMX_SIMD} STREQUAL "IBM_QPX")
 
     try_compile(TEST_QPX ${CMAKE_BINARY_DIR}
@@ -285,27 +285,25 @@ elseif(${GMX_SIMD} STREQUAL "IBM_QPX")
 
 elseif(${GMX_SIMD} STREQUAL "SPARC64_HPC_ACE")
 
+    # Note that GMX_RELAXED_DOUBLE_PRECISION is enabled by default in the top-level CMakeLists.txt
+
     set(GMX_SIMD_SPARC64_HPC_ACE 1)
     set(SIMD_STATUS_MESSAGE "Enabling Sparc64 HPC-ACE SIMD instructions")
 
 elseif(${GMX_SIMD} STREQUAL "REFERENCE")
 
-    add_definitions(-DGMX_SIMD_REFERENCE)
-    if(${GMX_NBNXN_REF_KERNEL_TYPE} STREQUAL "4xn")
-        if(${GMX_NBNXN_REF_KERNEL_WIDTH} STREQUAL "2" OR ${GMX_NBNXN_REF_KERNEL_WIDTH} STREQUAL "4" OR ${GMX_NBNXN_REF_KERNEL_WIDTH} STREQUAL "8")
-            add_definitions(-DGMX_NBNXN_SIMD_4XN -DGMX_SIMD_REF_WIDTH=${GMX_NBNXN_REF_KERNEL_WIDTH})
-        else()
-            message(FATAL_ERROR "Unsupported width for 4xn reference kernels")
-        endif()
-    elseif(${GMX_NBNXN_REF_KERNEL_TYPE} STREQUAL "2xnn")
-        if(${GMX_NBNXN_REF_KERNEL_WIDTH} STREQUAL "8" OR ${GMX_NBNXN_REF_KERNEL_WIDTH} STREQUAL "16")
-            add_definitions(-DGMX_NBNXN_SIMD_2XNN -DGMX_SIMD_REF_WIDTH=${GMX_NBNXN_REF_KERNEL_WIDTH})
-        else()
-            message(FATAL_ERROR "Unsupported width for 2xn reference kernels")
-        endif()
-    else()
-        message(FATAL_ERROR "Unsupported kernel type")
+    # NB: This file handles settings for the SIMD module, so in the interest 
+    # of proper modularization, please do NOT put any verlet kernel settings in this file.
+
+    if(GMX_SIMD_REF_FLOAT_WIDTH)
+        add_definitions(-DGMX_SIMD_REF_FLOAT_WIDTH=${GMX_SIMD_REF_FLOAT_WIDTH})
     endif()
+    if(GMX_SIMD_REF_DOUBLE_WIDTH)
+       add_definitions(-DGMX_SIMD_REF_DOUBLE_WIDTH=${GMX_SIMD_REF_DOUBLE_WIDTH})
+    endif()
+
+    set(GMX_SIMD_REFERENCE 1)
+    set(SIMD_STATUS_MESSAGE "Enabling reference (emulated) SIMD instructions.")
 
 else()
     gmx_invalid_option_value(GMX_SIMD)
@@ -317,5 +315,24 @@ if (SIMD_CHANGED AND DEFINED SIMD_STATUS_MESSAGE)
     message(STATUS "${SIMD_STATUS_MESSAGE}")
 endif()
 
+# By default, 32-bit windows cannot pass SIMD (SSE/AVX) arguments in registers,
+# and even on 64-bit (all platforms) it is only used for a handful of arguments.
+# The __vectorcall (MSVC, from MSVC2013) or __regcall (ICC) calling conventions
+# enable this, which is critical to enable 32-bit SIMD and improves performance
+# for 64-bit SIMD.
+# Check if the compiler supports one of these, and in that case set gmx_simdcall
+# to that string. If we do not have any such calling convention modifier, set it
+# to an empty string.
+if(NOT DEFINED GMX_SIMD_CALLING_CONVENTION)
+    foreach(callconv __vectorcall __regcall "")
+        set(callconv_compile_var "_callconv_${callconv}")
+        check_c_source_compiles("int ${callconv} f(int i) {return i;} int main(void) {return f(0);}" ${callconv_compile_var})
+        if(${callconv_compile_var})
+            set(GMX_SIMD_CALLING_CONVENTION "${callconv}" CACHE INTERNAL "Calling convention for SIMD routines" FORCE)
+            break()
+        endif()
+    endforeach()
+endif()
+
 endmacro()