Skip to content
Snippets Groups Projects
Commit 6fa41f7c authored by Michael Kuron, committed by Markus Holzer
Browse files

Fix RNG vectorization for LB

parent 584b4255
Branches
No related tags found
No related merge requests found
...@@ -192,7 +192,9 @@ class CBackend: ...@@ -192,7 +192,9 @@ class CBackend:
def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'): def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
if sympy_printer is None: if sympy_printer is None:
if vector_instruction_set is not None: if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) self.vector_sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
self.scalar_sympy_printer = CustomSympyPrinter()
self.sympy_printer = self.vector_sympy_printer
else: else:
self.sympy_printer = CustomSympyPrinter() self.sympy_printer = CustomSympyPrinter()
else: else:
...@@ -259,6 +261,12 @@ class CBackend: ...@@ -259,6 +261,12 @@ class CBackend:
prefix = "\n".join(node.prefix_lines) prefix = "\n".join(node.prefix_lines)
if prefix: if prefix:
prefix += "\n" prefix += "\n"
if self._vector_instruction_set and hasattr(node, 'instruction_set') and node.instruction_set is None:
# the tail loop must not be vectorized
self.sympy_printer = self.scalar_sympy_printer
code = f"{prefix}{loop_str}\n{self._print(node.body)}"
self.sympy_printer = self.vector_sympy_printer
return code
return f"{prefix}{loop_str}\n{self._print(node.body)}" return f"{prefix}{loop_str}\n{self._print(node.body)}"
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
...@@ -670,7 +678,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -670,7 +678,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
else: else:
is_boolean = get_type_of_expression(arg) == create_type("bool") is_boolean = get_type_of_expression(arg) == create_type("bool")
is_integer = get_type_of_expression(arg) == create_type("int") or \ is_integer = get_type_of_expression(arg) == create_type("int") or \
(isinstance(arg, TypedSymbol) and arg.dtype.is_int()) (isinstance(arg, TypedSymbol) and not isinstance(arg.dtype, VectorType) and arg.dtype.is_int())
instruction = 'makeVecConstBool' if is_boolean else \ instruction = 'makeVecConstBool' if is_boolean else \
'makeVecConstInt' if is_integer else 'makeVecConst' 'makeVecConstInt' if is_integer else 'makeVecConst'
return self.instruction_set[instruction].format(self._print(arg), **self._kwargs) return self.instruction_set[instruction].format(self._print(arg), **self._kwargs)
......
...@@ -126,7 +126,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', ...@@ -126,7 +126,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
vector_width = vector_is['width'] vector_width = vector_is['width']
kernel_ast.instruction_set = vector_is kernel_ast.instruction_set = vector_is
vectorize_rng(kernel_ast, vector_width)
strided = 'storeS' in vector_is and 'loadS' in vector_is strided = 'storeS' in vector_is and 'loadS' in vector_is
keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU'] keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU']
vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal, vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
...@@ -134,24 +133,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', ...@@ -134,24 +133,6 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
insert_vector_casts(kernel_ast) insert_vector_casts(kernel_ast)
def vectorize_rng(kernel_ast, vector_width):
    """Swap the scalar result symbols of every RNG node for vector-typed ones.

    The tree below ``kernel_ast`` is walked recursively through ``.args``.
    For each child that is an ``RNGBase`` instance, every result symbol is
    re-created as a ``TypedSymbol`` of the same name whose type is a
    ``VectorType`` of width ``vector_width`` wrapping the original dtype, and
    the node's ``_symbols_defined`` is overwritten with the new symbols.
    RNG nodes themselves are not descended into.

    Afterwards the collected scalar -> vector mapping is handed to
    ``fast_subs`` on ``kernel_ast.body``, with RNG nodes passed as the
    ``skip`` predicate.

    Args:
        kernel_ast: root node of the kernel; must expose ``args`` and ``body``.
        vector_width: ``width`` passed to ``VectorType`` for the new symbols.

    Returns:
        None. ``kernel_ast`` is modified as a side effect; the return value
        of ``fast_subs`` is discarded.
    """
    from pystencils.rng import RNGBase

    scalar_to_vector = {}

    def widen_rng_results(parent):
        # Depth-first walk; only non-RNG children are recursed into.
        for child in parent.args:
            if not isinstance(child, RNGBase):
                widen_rng_results(child)
                continue
            widened = [TypedSymbol(symbol.name, VectorType(symbol.dtype, width=vector_width))
                       for symbol in child.result_symbols]
            scalar_to_vector.update(zip(child.result_symbols, widened))
            child._symbols_defined = set(widened)

    def is_rng(expr):
        return isinstance(expr, RNGBase)

    widen_rng_results(kernel_ast)
    # NOTE(review): presumably `skip` keeps fast_subs from rewriting inside RNG nodes
    # -- confirm against fast_subs.
    fast_subs(kernel_ast.body, scalar_to_vector, skip=is_rng)
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields, def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
strided, keep_loop_stop, assume_sufficient_line_padding): strided, keep_loop_stop, assume_sufficient_line_padding):
"""Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
...@@ -173,6 +154,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a ...@@ -173,6 +154,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start
loop_nodes = [l for l in cut_loop(loop_node, [cutting_point]).args if isinstance(l, ast.LoopOverCoordinate)] loop_nodes = [l for l in cut_loop(loop_node, [cutting_point]).args if isinstance(l, ast.LoopOverCoordinate)]
assert len(loop_nodes) in (0, 1, 2) # 2 for main and tail loop, 1 if loop range divisible by vector width assert len(loop_nodes) in (0, 1, 2) # 2 for main and tail loop, 1 if loop range divisible by vector width
if len(loop_nodes) == 2:
loop_nodes[1].instruction_set = None
if len(loop_nodes) == 0: if len(loop_nodes) == 0:
continue continue
loop_node = loop_nodes[0] loop_node = loop_nodes[0]
...@@ -225,6 +208,15 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a ...@@ -225,6 +208,15 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
mask_conditionals(loop_node) mask_conditionals(loop_node)
from pystencils.rng import RNGBase
substitutions = {}
for rng in loop_node.atoms(RNGBase):
new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width))
for s in rng.result_symbols]
substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)})
rng._symbols_defined = set(new_result_symbols)
fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase))
def mask_conditionals(loop_body): def mask_conditionals(loop_body):
def visit_node(node, mask): def visit_node(node, mask):
...@@ -322,6 +314,9 @@ def insert_vector_casts(ast_node): ...@@ -322,6 +314,9 @@ def insert_vector_casts(ast_node):
return expr return expr
def visit_node(node, substitution_dict): def visit_node(node, substitution_dict):
if hasattr(node, 'instruction_set') and node.instruction_set is None:
# the tail loop must not be vectorized
return
substitution_dict = substitution_dict.copy() substitution_dict = substitution_dict.copy()
for arg in node.args: for arg in node.args:
if isinstance(arg, ast.SympyAssignment): if isinstance(arg, ast.SympyAssignment):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment