From 6d08ea7a1089d3b27eb6ffd0220918d16ab6008f Mon Sep 17 00:00:00 2001 From: Michael Kuron <m.kuron@gmx.de> Date: Tue, 27 Apr 2021 22:53:17 +0200 Subject: [PATCH] determine stride for scatter/gather --- pystencils/backends/cbackend.py | 12 +++++++-- pystencils/cpu/vectorization.py | 18 +++++++------ pystencils/data_types.py | 4 +-- .../test_vectorization_specific.py | 25 +++++++++++++++++++ 4 files changed, 47 insertions(+), 12 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 988a8e518..71ef00afc 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -256,7 +256,7 @@ class CBackend: lhs_type = get_type_of_expression(node.lhs) printed_mask = "" if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): - arg, data_type, aligned, nontemporal, mask = node.lhs.args + arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args instr = 'storeU' if aligned: instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA' @@ -281,6 +281,12 @@ class CBackend: rhs = node.rhs ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0]) + + if stride != 1: + instr = 'maskScatter' if mask != True else 'scatter' # NOQA + return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs), + stride, printed_mask) + ';' + pre_code = '' if nontemporal and 'cachelineZero' in self._vector_instruction_set: pre_code = f"if (((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0) " + "{\n\t" + \ @@ -605,7 +611,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): def _print_Function(self, expr): if isinstance(expr, vector_memory_access): - arg, data_type, aligned, _, mask = expr.args + arg, data_type, aligned, _, mask, stride = expr.args + if stride != 1: + return self.instruction_set['gather'].format("& " + self._print(arg), stride) instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] return instruction.format("& " + self._print(arg)) elif isinstance(expr, cast_func): diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 0de34b40b..16f0a1563 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -80,8 +80,9 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', kernel_ast.instruction_set = vector_is vectorize_rng(kernel_ast, vector_width) - vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, - nontemporal, assume_sufficient_line_padding) + scattergather = 'scatter' in vector_is and 'gather' in vector_is + vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal, + scattergather, assume_sufficient_line_padding) insert_vector_casts(kernel_ast) @@ -104,7 +105,7 @@ def vectorize_rng(kernel_ast, vector_width): def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields, - assume_sufficient_line_padding): + scattergather, assume_sufficient_line_padding): """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment) inner_loops = [n for n in all_loops if n.is_innermost_loop] @@ -135,7 +136,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a if loop_counter_symbol in index.atoms(sp.Symbol): loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms() aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0 - if not loop_counter_is_offset: + stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index) + if not loop_counter_is_offset and (not scattergather or loop_counter_symbol in stride.atoms()): successful = False break typed_symbol = base.label @@ -147,7 +149,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a nontemporal = False if hasattr(indexed, 'field'): nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields) - substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True) + substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True, + stride if scattergather else 1) if nontemporal: # insert NontemporalFence after the outermost loop parent = loop_node.parent @@ -188,7 +191,7 @@ def mask_conditionals(loop_body): node.condition_expr = vec_any(node.condition_expr) elif isinstance(node, ast.SympyAssignment): if mask is not True: - s = {ma: vector_memory_access(ma.args[0], ma.args[1], ma.args[2], ma.args[3], sp.And(mask, ma.args[4])) + s = {ma: vector_memory_access(*ma.args[0:4], sp.And(mask, ma.args[4]), *ma.args[5:]) for ma in node.atoms(vector_memory_access)} node.subs(s) else: @@ -205,8 +208,7 @@ def insert_vector_casts(ast_node): def visit_expr(expr): if isinstance(expr, vector_memory_access): - return vector_memory_access(expr.args[0], expr.args[1], expr.args[2], expr.args[3], - visit_expr(expr.args[4])) + return vector_memory_access(*expr.args[0:4], visit_expr(expr.args[4]), *expr.args[5:]) elif isinstance(expr, cast_func): return expr elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set: diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 46abd84f3..baf0a9674 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -195,8 +195,8 @@ class boolean_cast_func(cast_func, Boolean): # noinspection PyPep8Naming class vector_memory_access(cast_func): - # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none) - nargs = (5,) + # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride + nargs = (6,) # noinspection PyPep8Naming diff --git a/pystencils_tests/test_vectorization_specific.py b/pystencils_tests/test_vectorization_specific.py index df0b9d943..f579b4e46 100644 --- a/pystencils_tests/test_vectorization_specific.py +++ b/pystencils_tests/test_vectorization_specific.py @@ -53,6 +53,31 @@ def test_vectorized_abs(instruction_set, dtype): np.testing.assert_equal(np.sum(dst[1:-1, 1:-1]), 2 ** 2 * 2 ** 3) +@pytest.mark.parametrize('dtype', ('float', 'double')) +@pytest.mark.parametrize('instruction_set', supported_instruction_sets) +def test_scatter_gather(instruction_set, dtype): + f, g = ps.fields(f"f, g : float{64 if dtype == 'double' else 32}[2D]") + update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)] + if 'scatter' not in get_vector_instruction_set(dtype, instruction_set) and not instruction_set == 'avx512' and not instruction_set.startswith('sve'): + with pytest.warns(UserWarning) as warn: + ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set}) + assert 'Could not vectorize loop' in warn[0].message.args[0] + else: + with pytest.warns(None) as warn: + ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set}) + assert len(warn) == 0 + func = ast.compile() + ref_func = ps.create_kernel(update_rule).compile() + + arr = np.random.random((23 + 2, 17 + 2)).astype(np.float64 if dtype == 'double' else np.float32) + dst = np.zeros_like(arr, dtype=np.float64 if dtype == 'double' else np.float32) + ref = np.zeros_like(arr, dtype=np.float64 if dtype == 'double' else np.float32) + + func(g=dst, f=arr) + ref_func(g=ref, f=arr) + np.testing.assert_almost_equal(dst, ref, 13 if dtype == 'double' else 5) + + @pytest.mark.parametrize('dtype', ('float', 'double')) @pytest.mark.parametrize('instruction_set', supported_instruction_sets) @pytest.mark.parametrize('gl_field, gl_kernel', [(1, 0), (0, 1), (1, 1)]) -- GitLab