Skip to content
Snippets Groups Projects
Commit a946d58e authored by Michael Kuron
Browse files

gather/scatter vector instructions

parent 6d08ea7a
No related branches found
No related tags found
1 merge request: !241 "Vector scatter/gather support"
Pipeline #31780 passed
...@@ -88,9 +88,16 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): ...@@ -88,9 +88,16 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})' result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})'
result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})' result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
result['scatter'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{2}") + ', {1})'
result['gather'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{1}") + ')'
result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})" result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
result[data_type] = f'svfloat{bits[data_type]}_st' result['float'] = 'svfloat32_st'
result['double'] = 'svfloat64_st'
result['int'] = f'svint{bits["int"]}_st' result['int'] = f'svint{bits["int"]}_st'
result['bool'] = 'svbool_st' result['bool'] = 'svbool_st'
...@@ -102,6 +109,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): ...@@ -102,6 +109,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['any'] = f'svptest_any({predicate}, {{0}})' result['any'] = f'svptest_any({predicate}, {{0}})'
result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}' result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}'
result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
result['maskStoreA'] = result['storeA'].replace(predicate, '{2}')
result['maskScatter'] = result['scatter'].replace(predicate, '{3}')
result['compile_flags'] = [f'-msve-vector-bits={bitwidth}'] result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
else: else:
result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})' result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
......
...@@ -130,7 +130,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ...@@ -130,7 +130,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
result['all'] = f"{pre}_movemask_{suf}({{0}}) == {hex(2**result['width']-1)}" result['all'] = f"{pre}_movemask_{suf}({{0}}) == {hex(2**result['width']-1)}"
if instruction_set == 'avx512': if instruction_set == 'avx512':
size = 8 if data_type == 'double' else 16 size = result['width']
result['&'] = f'_kand_mask{size}({{0}}, {{1}})' result['&'] = f'_kand_mask{size}({{0}}, {{1}})'
result['|'] = f'_kor_mask{size}({{0}}, {{1}})' result['|'] = f'_kor_mask{size}({{0}}, {{1}})'
result['any'] = f'!_ktestz_mask{size}_u8({{0}}, {{0}})' result['any'] = f'!_ktestz_mask{size}_u8({{0}}, {{0}})'
...@@ -145,6 +145,14 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ...@@ -145,6 +145,14 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)]) params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
result['makeVecConstBool'] = f"__mmask8(({params}) )" result['makeVecConstBool'] = f"__mmask8(({params}) )"
vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(result['width'])][::-1]) + ')'
vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))'
result['scatter'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \
f', {{1}}, {64//size})'
result['maskScatter'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \
f', {{1}}, {64//size})'
result['gather'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})'
if instruction_set == 'avx' and data_type == 'float': if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})" result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment