From 393cf72a7e5f73818f602cb2ba8a5f3e6158b5e8 Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Sat, 14 Nov 2020 15:35:18 +0100
Subject: [PATCH] ARM NEON vectorization

---
# Review notes (this section is ignored when the patch is applied).
#
# NOTE(review): line breaks and indentation below were reconstructed from a
# whitespace-collapsed copy of this patch. Hunk line counts match the @@
# headers, but the indentation of context lines is a best guess -- verify
# against the target tree before applying.
#
# What the patch does, per file:
#   * cbackend.py: adds VectorizedCustomSympyPrinter._print_Abs, which formats
#     instruction_set['abs'] when that key exists and the argument is a
#     vector_memory_access, and otherwise defers to super()._print_Abs.
#   * simd_instruction_sets.py: get_vector_instruction_set becomes a dispatcher
#     ('neon'/'sve' -> get_vector_instruction_set_arm, everything else ->
#     get_vector_instruction_set_x86, the renamed original). Adds an 'abs'
#     entry next to the mask-register ('__mmask') entries, builds the two
#     touched 'rsqrt' strings from pre/suf instead of hard-coded prefixes, and
#     reports "neon" from get_supported_instruction_sets when the CPU flags
#     contain 'neon'.
#   * vectorization.py: the Abs -> Piecewise rewrite in insert_vector_casts now
#     only runs when 'abs' is not in ast_node.instruction_set.
#
# Points to confirm before merging:
#   1. In get_vector_instruction_set_arm the 'sve' branch builds `result`
#      without a 'width' key, while the cbackend.py context lines read
#      self.instruction_set['width'].
#   2. `result` is only bound for 'neon' and 'sve'; any other value passed
#      directly to get_vector_instruction_set_arm fails at the first
#      `result[op] = ...`. The dispatcher never does this.
#   3. vectorization.py skips the Piecewise fallback for every Abs once 'abs'
#      is available, but _print_Abs uses the intrinsic only for
#      vector_memory_access arguments; other Abs expressions go to
#      super()._print_Abs. Whether that handles vector operands is not visible
#      in this patch -- confirm.
#   4. Generated ARM names follow 'v<op>q_f<size>(...)' (e.g. vmulq_f64,
#      vandq_f64, vceqq_f64) and get an extra 's' prefix for SVE (svmulq_f64),
#      and '!=' wraps '==' in vmvnq_u<size> / svnot_u<size>. Check each name
#      against arm_neon.h / arm_sve.h; presumably the bitwise and 64-bit
#      variants and all of the SVE spellings need attention -- TODO confirm.
#   5. The 'rsqrt' rewrites are behaviour-preserving only if pre/suf equal
#      '_mm512' / '_mm256' and 'ps' in those branches; their definitions are
#      outside the visible hunks.
#
 pystencils/backends/cbackend.py              |  5 ++
 pystencils/backends/simd_instruction_sets.py | 65 +++++++++++++++++++-
 pystencils/cpu/vectorization.py              |  2 +-
 3 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 94f347729..cb782a1cc 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -531,6 +531,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
             assert self.instruction_set['width'] == expr_type.width
             return None
 
+    def _print_Abs(self, expr):
+        if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access):
+            return self.instruction_set['abs'].format(self._print(expr.args[0]))
+        return super()._print_Abs(expr)
+
     def _print_Function(self, expr):
         if isinstance(expr, vector_memory_access):
             arg, data_type, aligned, _, mask = expr.args
diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py
index 7e2be4dee..0a55abb27 100644
--- a/pystencils/backends/simd_instruction_sets.py
+++ b/pystencils/backends/simd_instruction_sets.py
@@ -1,7 +1,12 @@
+def get_vector_instruction_set(data_type='double', instruction_set='avx'):
+    if instruction_set in ['neon', 'sve']:
+        return get_vector_instruction_set_arm(data_type, instruction_set)
+    else:
+        return get_vector_instruction_set_x86(data_type, instruction_set)
 
 
 # noinspection SpellCheckingInspection
-def get_vector_instruction_set(data_type='double', instruction_set='avx'):
+def get_vector_instruction_set_x86(data_type, instruction_set):
     comparisons = {
         '==': '_CMP_EQ_UQ',
         '!=': '_CMP_NEQ_UQ',
@@ -137,7 +142,8 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
         result['any'] = '!_ktestz_mask%d_u8({0}, {0})' % (size, )
         result['all'] = '_kortestc_mask%d_u8({0}, {0})' % (size, )
         result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf)
-        result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
+        result['rsqrt'] = "%s_rsqrt14_%s({0})" % (pre, suf)
+        result['abs'] = "%s_abs_%s({0})" % (pre, suf)
         result['bool'] = "__mmask%d" % (size,)
 
         params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
@@ -146,7 +152,57 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
         result['makeVecConstBool'] = f"__mmask8(({params}) )"
 
     if instruction_set == 'avx' and data_type == 'float':
-        result['rsqrt'] = "_mm256_rsqrt_ps({0})"
+        result['rsqrt'] = "%s_rsqrt_%s({0})" % (pre, suf)
+
+    return result
+
+
+def get_vector_instruction_set_arm(data_type, instruction_set):
+    size = 64 if data_type == 'double' else 32
+
+    ops = {
+        '*': ('mul', 2), '+': ('add', 2), '-': ('sub', 2), '/': ('div', 2),
+        'sqrt': ('sqrt', 1), 'rsqrt': (None, 1), 'abs': ('abs', 1),
+        '==': ('ceq', 2), '<=': ('cle', 2), '<': ('clt', 2), '>=': ('cge', 2), '>': ('cgt', 2),
+        '&': ('and', 2), '|': ('orr', 2),
+        'storeU': ('st1', 2), 'loadU': ('ld1', 1), 'store': ('st1', 2), 'load': ('ld1', 1),
+    }
+
+    if instruction_set == 'neon':
+        width = {
+            ("double", "neon"): 2,
+            ("float", "neon"): 4
+        }
+
+        result = {
+            'width': width[(data_type, instruction_set)],
+            'headers': ['<arm_neon.h>']
+        }
+
+        result['double'] = "float64x%d_t" % (width[('double', instruction_set)])
+        result['float'] = "float32x%d_t" % (width[('float', instruction_set)])
+    elif instruction_set == 'sve':
+        result = {
+            'headers': ['<arm_sve.h>']
+        }
+
+        result['double'] = "svfloat64_t"
+        result['float'] = "svfloat32_t"
+
+    for op, instr in ops.items():
+        instr, arity = instr
+        if instr:
+            result[op] = 'v%sq_f%d({%s})' % (instr, size, '},{'.join([str(i) for i in range(arity)]))
+            if instruction_set == 'sve':
+                result[op] = 's' + result[op]
+        else:
+            result[op] = None
+
+    if instruction_set == 'sve':
+        result['!='] = 'svnot_u%d(%s)' % (size, result['=='])
+    else:
+        result['!='] = 'vmvnq_u%d(%s)' % (size, result['=='])
+    result['stream'] = '__builtin_nontemporal_store({1}, {0})'
 
     return result
 
@@ -162,6 +218,7 @@ def get_supported_instruction_sets():
     required_sse_flags = {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'}
     required_avx_flags = {'avx'}
     required_avx512_flags = {'avx512f'}
+    required_neon_flags = {'neon'}
     flags = set(get_cpu_info()['flags'])
     if flags.issuperset(required_sse_flags):
         result.append("sse")
@@ -169,4 +226,6 @@ def get_supported_instruction_sets():
         result.append("avx")
     if flags.issuperset(required_avx512_flags):
         result.append("avx512")
+    if flags.issuperset(required_neon_flags):
+        result.append("neon")
     return result
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index 0ee5200a0..cf5145656 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -176,7 +176,7 @@ def insert_vector_casts(ast_node):
                                         visit_expr(expr.args[4]))
         elif isinstance(expr, cast_func):
             return expr
-        elif expr.func is sp.Abs:
+        elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
             new_arg = visit_expr(expr.args[0])
             pw = sp.Piecewise((-1 * new_arg, new_arg < 0), (new_arg, True))
             return visit_expr(pw)
-- 
GitLab