From a946d58ea90b43d8e43d1554d7a2c502bb5334ec Mon Sep 17 00:00:00 2001 From: Michael Kuron <m.kuron@gmx.de> Date: Wed, 28 Apr 2021 19:29:12 +0200 Subject: [PATCH] gather/scatter vector instructions --- pystencils/backends/arm_instruction_sets.py | 13 ++++++++++++- pystencils/backends/x86_instruction_sets.py | 10 +++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py index 9f7b4ee22..5318dffeb 100644 --- a/pystencils/backends/arm_instruction_sets.py +++ b/pystencils/backends/arm_instruction_sets.py @@ -88,9 +88,16 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})' result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})' + vindex = f'svindex_u{bits[data_type]}(0, {{0}})' + result['scatter'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \ + vindex.format("{2}") + ', {1})' + result['gather'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \ + vindex.format("{1}") + ')' + result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})" - result[data_type] = f'svfloat{bits[data_type]}_st' + result['float'] = 'svfloat32_st' + result['double'] = 'svfloat64_st' result['int'] = f'svint{bits["int"]}_st' result['bool'] = 'svbool_st' @@ -102,6 +109,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): result['any'] = f'svptest_any({predicate}, {{0}})' result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}' + result['maskStoreU'] = result['storeU'].replace(predicate, '{2}') + result['maskStoreA'] = result['storeA'].replace(predicate, '{2}') + result['maskScatter'] = result['scatter'].replace(predicate, '{3}') + result['compile_flags'] = [f'-msve-vector-bits={bitwidth}'] else: result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})' diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py index 50005c5ae..6bf12b448 100644 --- a/pystencils/backends/x86_instruction_sets.py +++ b/pystencils/backends/x86_instruction_sets.py @@ -130,7 +130,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): result['all'] = f"{pre}_movemask_{suf}({{0}}) == {hex(2**result['width']-1)}" if instruction_set == 'avx512': - size = 8 if data_type == 'double' else 16 + size = result['width'] result['&'] = f'_kand_mask{size}({{0}}, {{1}})' result['|'] = f'_kor_mask{size}({{0}}, {{1}})' result['any'] = f'!_ktestz_mask{size}_u8({{0}}, {{0}})' @@ -145,6 +145,14 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)]) result['makeVecConstBool'] = f"__mmask8(({params}) )" + vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(result['width'])][::-1]) + ')' + vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))' + result['scatter'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \ + f', {{1}}, {64//size})' + result['maskScatter'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \ + f', {{1}}, {64//size})' + result['gather'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})' + if instruction_set == 'avx' and data_type == 'float': result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})" -- GitLab