From bcb36f7752b633127a62ee3abd0b703056cfe33b Mon Sep 17 00:00:00 2001 From: Michael Kuron <m.kuron@gmx.de> Date: Thu, 27 May 2021 15:28:47 +0200 Subject: [PATCH] only vectorize the RNG symbols if the loop was vectorized --- pystencils/cpu/vectorization.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index e4a27ad34..b54b78d85 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -126,7 +126,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', vector_width = vector_is['width'] kernel_ast.instruction_set = vector_is - vectorize_rng(kernel_ast, vector_width) strided = 'storeS' in vector_is and 'loadS' in vector_is keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU'] vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal, @@ -134,24 +133,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', insert_vector_casts(kernel_ast) -def vectorize_rng(kernel_ast, vector_width): - """Replace scalar result symbols on RNG nodes with vectorial ones""" - from pystencils.rng import RNGBase - subst = {} - - def visit_node(node): - for arg in node.args: - if isinstance(arg, RNGBase): - new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width)) - for s in arg.result_symbols] - subst.update({s[0]: s[1] for s in zip(arg.result_symbols, new_result_symbols)}) - arg._symbols_defined = set(new_result_symbols) - else: - visit_node(arg) - visit_node(kernel_ast) - fast_subs(kernel_ast.body, subst, skip=lambda e: isinstance(e, RNGBase)) - - def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields, strided, keep_loop_stop, assume_sufficient_line_padding): """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" @@ -227,6 +208,15 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a mask_conditionals(loop_node) + from pystencils.rng import RNGBase + substitutions = {} + for rng in loop_node.atoms(RNGBase): + new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width)) + for s in rng.result_symbols] + substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)}) + rng._symbols_defined = set(new_result_symbols) + fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase)) + def mask_conditionals(loop_body): def visit_node(node, mask): -- GitLab