Skip to content
Snippets Groups Projects
Commit bcb36f77 authored by Michael Kuron
Browse files

only vectorize the RNG symbols if the loop was vectorized

parent f9a17154
No related branches found
No related tags found
1 merge request: !248 "Fix RNG vectorization for LB"
Pipeline #32347 passed
......@@ -126,7 +126,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
vector_width = vector_is['width']
kernel_ast.instruction_set = vector_is
vectorize_rng(kernel_ast, vector_width)
strided = 'storeS' in vector_is and 'loadS' in vector_is
keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU']
vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
......@@ -134,24 +133,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
insert_vector_casts(kernel_ast)
def vectorize_rng(kernel_ast, vector_width):
    """Give every RNG node in the kernel vector-typed result symbols.

    The ``args`` tree of ``kernel_ast`` is walked recursively. For each ``RNGBase`` node encountered, every
    result symbol gets a counterpart ``TypedSymbol`` with the same name and a ``VectorType`` of the given width;
    these become the node's ``_symbols_defined``. Afterwards the scalar symbols are substituted by the vector
    ones throughout ``kernel_ast.body``, leaving the RNG nodes themselves untouched.

    Args:
        kernel_ast: root node of the kernel; must expose ``args`` and ``body``
        vector_width: width passed on to ``VectorType`` for the new result symbols
    """
    from pystencils.rng import RNGBase

    scalar_to_vector = {}

    def is_rng(expr):
        return isinstance(expr, RNGBase)

    def collect(node):
        for child in node.args:
            if not is_rng(child):
                # only non-RNG nodes are descended into
                collect(child)
                continue
            vector_symbols = []
            for scalar in child.result_symbols:
                vector_symbols.append(TypedSymbol(scalar.name, VectorType(scalar.dtype, width=vector_width)))
            scalar_to_vector.update(zip(child.result_symbols, vector_symbols))
            child._symbols_defined = set(vector_symbols)

    collect(kernel_ast)
    fast_subs(kernel_ast.body, scalar_to_vector, skip=is_rng)
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
strided, keep_loop_stop, assume_sufficient_line_padding):
"""Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
......@@ -227,6 +208,15 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
mask_conditionals(loop_node)
from pystencils.rng import RNGBase
substitutions = {}
for rng in loop_node.atoms(RNGBase):
new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width))
for s in rng.result_symbols]
substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)})
rng._symbols_defined = set(new_result_symbols)
fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase))
def mask_conditionals(loop_body):
def visit_node(node, mask):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment