Skip to content
Snippets Groups Projects
Commit 6fcc0ccc authored by Michael Kuron
Browse files

Rename scatter/gather to strided

Some instruction sets have separate strided and scatter/gather operations, e.g. the RISC-V "V" vector extension (RVV) or NEC SX.
parent 0ee06d2d
No related branches found
No related tags found
1 merge request: !234 Sizeless vectorization
Pipeline #32226 passed
...@@ -102,9 +102,9 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): ...@@ -102,9 +102,9 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})' result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
vindex = f'svindex_u{bits[data_type]}(0, {{0}})' vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
result['scatter'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \ result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{2}") + ', {1})' vindex.format("{2}") + ', {1})'
result['gather'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \ result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{1}") + ')' vindex.format("{1}") + ')'
result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})" result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
...@@ -124,7 +124,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): ...@@ -124,7 +124,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['maskStoreU'] = result['storeU'].replace(predicate, '{2}') result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
result['maskStoreA'] = result['storeA'].replace(predicate, '{2}') result['maskStoreA'] = result['storeA'].replace(predicate, '{2}')
result['maskScatter'] = result['scatter'].replace(predicate, '{3}') result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
if instruction_set != 'sve': if instruction_set != 'sve':
result['compile_flags'] = [f'-msve-vector-bits={bitwidth}'] result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
......
...@@ -311,7 +311,7 @@ class CBackend: ...@@ -311,7 +311,7 @@ class CBackend:
ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0]) ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])
if stride != 1: if stride != 1:
instr = 'maskScatter' if mask != True else 'scatter' # NOQA instr = 'maskStoreS' if mask != True else 'storeS' # NOQA
return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs), return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
stride, printed_mask, **self._kwargs) + ';' stride, printed_mask, **self._kwargs) + ';'
...@@ -648,7 +648,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -648,7 +648,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if isinstance(expr, vector_memory_access): if isinstance(expr, vector_memory_access):
arg, data_type, aligned, _, mask, stride = expr.args arg, data_type, aligned, _, mask, stride = expr.args
if stride != 1: if stride != 1:
return self.instruction_set['gather'].format("& " + self._print(arg), stride, **self._kwargs) return self.instruction_set['loadS'].format("& " + self._print(arg), stride, **self._kwargs)
instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
return instruction.format("& " + self._print(arg), **self._kwargs) return instruction.format("& " + self._print(arg), **self._kwargs)
elif isinstance(expr, cast_func): elif isinstance(expr, cast_func):
......
...@@ -35,9 +35,9 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'): ...@@ -35,9 +35,9 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
'storeA': f'se{bits[data_type]}_v[0, 1]', 'storeA': f'se{bits[data_type]}_v[0, 1]',
'maskStoreU': f'se{bits[data_type]}_v[2, 0, 1]', 'maskStoreU': f'se{bits[data_type]}_v[2, 0, 1]',
'maskStoreA': f'se{bits[data_type]}_v[2, 0, 1]', 'maskStoreA': f'se{bits[data_type]}_v[2, 0, 1]',
'gather': f'lse{bits[data_type]}_v[0, 1]', 'loadS': f'lse{bits[data_type]}_v[0, 1]',
'scatter': f'sse{bits[data_type]}_v[0, 2, 1]', 'storeS': f'sse{bits[data_type]}_v[0, 2, 1]',
'maskScatter': f'sse{bits[data_type]}_v[2, 0, 3, 1]', 'maskStoreS': f'sse{bits[data_type]}_v[2, 0, 3, 1]',
'abs': 'fabs_v[0]', 'abs': 'fabs_v[0]',
'==': 'mfeq_vv[0, 1]', '==': 'mfeq_vv[0, 1]',
...@@ -90,9 +90,9 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'): ...@@ -90,9 +90,9 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
result['makeVecIndex'] = f'vmacc_vx_i{bits["int"]}m1({result["makeVecConstInt"]}, {{1}}, ' + \ result['makeVecIndex'] = f'vmacc_vx_i{bits["int"]}m1({result["makeVecConstInt"]}, {{1}}, ' + \
f'vid_v_i{bits["int"]}m1({int_vl}), {int_vl})' f'vid_v_i{bits["int"]}m1({int_vl}), {int_vl})'
result['scatter'] = result['scatter'].replace('{2}', f'{{2}}*{bits[data_type]//8}') result['storeS'] = result['storeS'].replace('{2}', f'{{2}}*{bits[data_type]//8}')
result['gather'] = result['gather'].replace('{1}', f'{{1}}*{bits[data_type]//8}') result['loadS'] = result['loadS'].replace('{1}', f'{{1}}*{bits[data_type]//8}')
result['maskScatter'] = result['maskScatter'].replace('{3}', f'{{3}}*{bits[data_type]//8}') result['maskStoreS'] = result['maskStoreS'].replace('{3}', f'{{3}}*{bits[data_type]//8}')
result['+int'] = f"vadd_vv_i{bits['int']}m1({{0}}, {{1}}, {int_vl})" result['+int'] = f"vadd_vv_i{bits['int']}m1({{0}}, {{1}}, {int_vl})"
......
...@@ -147,11 +147,11 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ...@@ -147,11 +147,11 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(result['width'])][::-1]) + ')' vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(result['width'])][::-1]) + ')'
vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))' vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))'
result['scatter'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \ result['storeS'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \
f', {{1}}, {64//size})' f', {{1}}, {64//size})'
result['maskScatter'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \ result['maskStoreS'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \
f', {{1}}, {64//size})' f', {{1}}, {64//size})'
result['gather'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})' result['loadS'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})'
if instruction_set == 'avx' and data_type == 'float': if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})" result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})"
......
...@@ -127,10 +127,10 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', ...@@ -127,10 +127,10 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
kernel_ast.instruction_set = vector_is kernel_ast.instruction_set = vector_is
vectorize_rng(kernel_ast, vector_width) vectorize_rng(kernel_ast, vector_width)
scattergather = 'scatter' in vector_is and 'gather' in vector_is strided = 'storeS' in vector_is and 'loadS' in vector_is
keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU'] keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU']
vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal, vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
scattergather, keep_loop_stop, assume_sufficient_line_padding) strided, keep_loop_stop, assume_sufficient_line_padding)
insert_vector_casts(kernel_ast) insert_vector_casts(kernel_ast)
...@@ -153,7 +153,7 @@ def vectorize_rng(kernel_ast, vector_width): ...@@ -153,7 +153,7 @@ def vectorize_rng(kernel_ast, vector_width):
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields, def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
scattergather, keep_loop_stop, assume_sufficient_line_padding): strided, keep_loop_stop, assume_sufficient_line_padding):
"""Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment) all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
inner_loops = [n for n in all_loops if n.is_innermost_loop] inner_loops = [n for n in all_loops if n.is_innermost_loop]
...@@ -187,7 +187,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a ...@@ -187,7 +187,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms() loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0 aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index) stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index)
if not loop_counter_is_offset and (not scattergather or loop_counter_symbol in stride.atoms()): if not loop_counter_is_offset and (not strided or loop_counter_symbol in stride.atoms()):
successful = False successful = False
break break
typed_symbol = base.label typed_symbol = base.label
...@@ -200,7 +200,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a ...@@ -200,7 +200,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
if hasattr(indexed, 'field'): if hasattr(indexed, 'field'):
nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields) nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True, substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True,
stride if scattergather else 1) stride if strided else 1)
if nontemporal: if nontemporal:
# insert NontemporalFence after the outermost loop # insert NontemporalFence after the outermost loop
parent = loop_node.parent parent = loop_node.parent
......
...@@ -55,10 +55,10 @@ def test_vectorized_abs(instruction_set, dtype): ...@@ -55,10 +55,10 @@ def test_vectorized_abs(instruction_set, dtype):
@pytest.mark.parametrize('dtype', ('float', 'double')) @pytest.mark.parametrize('dtype', ('float', 'double'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets) @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_scatter_gather(instruction_set, dtype): def test_strided(instruction_set, dtype):
f, g = ps.fields(f"f, g : float{64 if dtype == 'double' else 32}[2D]") f, g = ps.fields(f"f, g : float{64 if dtype == 'double' else 32}[2D]")
update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)] update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)]
if 'scatter' not in get_vector_instruction_set(dtype, instruction_set) and not instruction_set in ['avx512', 'rvv'] and not instruction_set.startswith('sve'): if 'storeS' not in get_vector_instruction_set(dtype, instruction_set) and not instruction_set in ['avx512', 'rvv'] and not instruction_set.startswith('sve'):
with pytest.warns(UserWarning) as warn: with pytest.warns(UserWarning) as warn:
ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set}) ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set})
assert 'Could not vectorize loop' in warn[0].message.args[0] assert 'Could not vectorize loop' in warn[0].message.args[0]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment