From d01fc61c4201f11f70828d45b8a5917d6f2e7887 Mon Sep 17 00:00:00 2001 From: Michael Kuron <m.kuron@gmx.de> Date: Mon, 27 Mar 2023 19:56:15 +0200 Subject: [PATCH] Properly detect and enable vectorization on ARM --- pystencils/backends/simd_instruction_sets.py | 5 ++--- pystencils/cpu/cpujit.py | 4 +--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py index cdb2ee5cf..d8cccf98a 100644 --- a/pystencils/backends/simd_instruction_sets.py +++ b/pystencils/backends/simd_instruction_sets.py @@ -43,8 +43,7 @@ def get_supported_instruction_sets(): return _cache.copy() if 'PYSTENCILS_SIMD' in os.environ: return os.environ['PYSTENCILS_SIMD'].split(',') - if (platform.system() == 'Darwin' or platform.system() == 'Linux') and platform.machine() == 'arm64': - # not supported by cpuinfo + if platform.system() == 'Darwin' and platform.machine() == 'arm64': # not supported by cpuinfo return ['neon'] elif platform.system() == 'Linux' and platform.machine().startswith('riscv'): # not supported by cpuinfo libc = CDLL('libc.so.6') @@ -72,7 +71,7 @@ def get_supported_instruction_sets(): required_sse_flags = {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'} required_avx_flags = {'avx', 'avx2'} required_avx512_flags = {'avx512f'} - required_neon_flags = {'neon'} + required_neon_flags = {'asimd'} required_sve_flags = {'sve'} flags = set(get_cpu_info()['flags']) if flags.issuperset(required_sse_flags): diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index aebefec91..c71700d2f 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -146,9 +146,7 @@ def read_config(): ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'), ('restrict_qualifier', '__restrict__') ]) - if platform.machine() == 'arm64': - default_compiler_config['flags'] = default_compiler_config['flags'].replace('-march=native', '') - elif platform.machine().startswith('ppc64'): + if platform.machine().startswith('ppc64') or platform.machine() == 'arm64': default_compiler_config['flags'] = default_compiler_config['flags'].replace('-march=native', '-mcpu=native') elif platform.system().lower() == 'windows': -- GitLab