Skip to content
Snippets Groups Projects
Commit 393cf72a authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

ARM NEON vectorization

parent 2d758462
1 merge request!187WIP: ARM NEON vectorization
Pipeline #27982 failed with stage
in 6 minutes and 12 seconds
......@@ -531,6 +531,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert self.instruction_set['width'] == expr_type.width
return None
def _print_Abs(self, expr):
if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access):
return self.instruction_set['abs'].format(self._print(expr.args[0]))
return super()._print_Abs(expr)
def _print_Function(self, expr):
if isinstance(expr, vector_memory_access):
arg, data_type, aligned, _, mask = expr.args
......
def get_vector_instruction_set(data_type='double', instruction_set='avx'):
if instruction_set in ['neon', 'sve']:
return get_vector_instruction_set_arm(data_type, instruction_set)
else:
return get_vector_instruction_set_x86(data_type, instruction_set)
# noinspection SpellCheckingInspection
def get_vector_instruction_set(data_type='double', instruction_set='avx'):
def get_vector_instruction_set_x86(data_type, instruction_set):
comparisons = {
'==': '_CMP_EQ_UQ',
'!=': '_CMP_NEQ_UQ',
......@@ -137,7 +142,8 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result['any'] = '!_ktestz_mask%d_u8({0}, {0})' % (size, )
result['all'] = '_kortestc_mask%d_u8({0}, {0})' % (size, )
result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf)
result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
result['rsqrt'] = "%s_rsqrt14_%s({0})" % (pre, suf)
result['abs'] = "%s_abs_%s({0})" % (pre, suf)
result['bool'] = "__mmask%d" % (size,)
params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
......@@ -146,7 +152,57 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result['makeVecConstBool'] = f"__mmask8(({params}) )"
if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = "_mm256_rsqrt_ps({0})"
result['rsqrt'] = "%s_rsqrt_%s({0})" % (pre, suf)
return result
def get_vector_instruction_set_arm(data_type, instruction_set):
size = 64 if data_type == 'double' else 32
ops = {
'*': ('mul', 2), '+': ('add', 2), '-': ('sub', 2), '/': ('div', 2),
'sqrt': ('sqrt', 1), 'rsqrt': (None, 1), 'abs': ('abs', 1),
'==': ('ceq', 2), '<=': ('cle', 2), '<': ('clt', 2), '>=': ('cge', 2), '>': ('cgt', 2),
'&': ('and', 2), '|': ('orr', 2),
'storeU': ('st1', 2), 'loadU': ('ld1', 1), 'store': ('st1', 2), 'load': ('ld1', 1),
}
if instruction_set == 'neon':
width = {
("double", "neon"): 2,
("float", "neon"): 4
}
result = {
'width': width[(data_type, instruction_set)],
'headers': ['<arm_neon.h>']
}
result['double'] = "float64x%d_t" % (width[('double', instruction_set)])
result['float'] = "float32x%d_t" % (width[('float', instruction_set)])
elif instruction_set == 'sve':
result = {
'headers': ['<arm_sve.h>']
}
result['double'] = "svfloat64_t"
result['float'] = "svfloat32_t"
for op, instr in ops.items():
instr, arity = instr
if instr:
result[op] = 'v%sq_f%d({%s})' % (instr, size, '},{'.join([str(i) for i in range(arity)]))
if instruction_set == 'sve':
result[op] = 's' + result[op]
else:
result[op] = None
if instruction_set == 'sve':
result['!='] = 'svnot_u%d(%s)' % (size, result['=='])
else:
result['!='] = 'vmvnq_u%d(%s)' % (size, result['=='])
result['stream'] = '__builtin_nontemporal_store({1}, {0})'
return result
......@@ -162,6 +218,7 @@ def get_supported_instruction_sets():
required_sse_flags = {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'}
required_avx_flags = {'avx'}
required_avx512_flags = {'avx512f'}
required_neon_flags = {'neon'}
flags = set(get_cpu_info()['flags'])
if flags.issuperset(required_sse_flags):
result.append("sse")
......@@ -169,4 +226,6 @@ def get_supported_instruction_sets():
result.append("avx")
if flags.issuperset(required_avx512_flags):
result.append("avx512")
if flags.issuperset(required_neon_flags):
result.append("neon")
return result
......@@ -176,7 +176,7 @@ def insert_vector_casts(ast_node):
visit_expr(expr.args[4]))
elif isinstance(expr, cast_func):
return expr
elif expr.func is sp.Abs:
elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
new_arg = visit_expr(expr.args[0])
pw = sp.Piecewise((-1 * new_arg, new_arg < 0), (new_arg, True))
return visit_expr(pw)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment