Skip to content
Snippets Groups Projects

Support ARM64 Streaming SVE

Merged Michael Kuron requested to merge sme into master
All threads resolved!
Files
8
@@ -16,9 +16,9 @@ def get_argument_string(function_shortcut, first=''):
def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
if instruction_set != 'neon' and not instruction_set.startswith('sve'):
if instruction_set not in ['neon', 'sme'] and not instruction_set.startswith('sve'):
raise NotImplementedError(instruction_set)
if instruction_set == 'sve':
if instruction_set in ['sve', 'sme']:
cmp = 'cmp'
elif instruction_set.startswith('sve'):
cmp = 'cmp'
@@ -52,7 +52,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result = dict()
if instruction_set == 'sve':
if instruction_set in ['sve', 'sme']:
width = 'svcntd()' if data_type == 'double' else 'svcntw()'
intwidth = 'svcntw()'
result['bytes'] = 'svcntb()'
@@ -60,14 +60,14 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
width = bitwidth // bits[data_type]
intwidth = bitwidth // bits['int']
result['bytes'] = bitwidth // 8
if instruction_set.startswith('sve'):
if instruction_set.startswith('sve') or instruction_set == 'sme':
prefix = 'sv'
suffix = f'_f{bits[data_type]}'
elif instruction_set == 'neon':
prefix = 'v'
suffix = f'q_f{bits[data_type]}'
if instruction_set == 'sve':
if instruction_set in ['sve', 'sme']:
predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})'
int_predicate = f'{prefix}whilelt_b{bits["int"]}_u64({{loop_counter}}, {{loop_stop}})'
else:
@@ -86,7 +86,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result[intrinsic_id] = prefix + name + suffix + undef + arg_string
if instruction_set == 'sve':
if instruction_set in ['sve', 'sme']:
from pystencils.backends.cbackend import CFunction
result['width'] = CFunction(width, "int")
result['intwidth'] = CFunction(intwidth, "int")
@@ -94,23 +94,24 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['width'] = width
result['intwidth'] = intwidth
if instruction_set.startswith('sve'):
if instruction_set.startswith('sve') or instruction_set == 'sme':
result['makeVecConst'] = f'svdup_f{bits[data_type]}' + '({0})'
result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})'
result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{2}") + ', {1})'
result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{1}") + ')'
if instruction_set != 'sme':
vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{2}") + ', {1})'
result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{1}") + ')'
result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set != "sve" else ""}t'
result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set != "sve" else ""}t'
result['int'] = f'svint{bits["int"]}_{"s" if instruction_set != "sve" else ""}t'
result['bool'] = f'svbool_{"s" if instruction_set != "sve" else ""}t'
result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set not in ["sve", "sme"] else ""}t'
result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set not in ["sve", "sme"] else ""}t'
result['int'] = f'svint{bits["int"]}_{"s" if instruction_set not in ["sve", "sme"] else ""}t'
result['bool'] = f'svbool_{"s" if instruction_set not in ["sve", "sme"] else ""}t'
result['headers'] = ['<arm_sve.h>', '"arm_neon_helpers.h"']
@@ -121,9 +122,12 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}'
result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
if instruction_set != 'sme':
result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
if instruction_set != 'sve':
if instruction_set == 'sme':
result['function_prefix'] = '__attribute__((arm_locally_streaming))'
elif instruction_set not in ['sve', 'sme']:
result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
else:
result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
Loading