Skip to content
Snippets Groups Projects
Commit 64d39897 authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon: Committed by Frederik Hennig
Browse files

Do not hoist declarations of mutated variables

parent 4b2de595
No related branches found
No related tags found
No related merge requests found
...@@ -27,6 +27,7 @@ class HoistContext: ...@@ -27,6 +27,7 @@ class HoistContext:
def __init__(self) -> None: def __init__(self) -> None:
self.hoisted_nodes: list[PsDeclaration] = [] self.hoisted_nodes: list[PsDeclaration] = []
self.assigned_symbols: set[PsSymbol] = set() self.assigned_symbols: set[PsSymbol] = set()
self.mutated_symbols: set[PsSymbol] = set()
self.invariant_symbols: set[PsSymbol] = set() self.invariant_symbols: set[PsSymbol] = set()
def _is_invariant(self, expr: PsExpression) -> bool: def _is_invariant(self, expr: PsExpression) -> bool:
...@@ -123,6 +124,7 @@ class HoistLoopInvariantDeclarations: ...@@ -123,6 +124,7 @@ class HoistLoopInvariantDeclarations:
"""Hoist invariant declarations out of the given loop.""" """Hoist invariant declarations out of the given loop."""
hc = HoistContext() hc = HoistContext()
hc.assigned_symbols.add(loop.counter.symbol) hc.assigned_symbols.add(loop.counter.symbol)
hc.mutated_symbols.add(loop.counter.symbol)
self._prepare_hoist(loop.body, hc) self._prepare_hoist(loop.body, hc)
self._hoist_from_block(loop.body, hc) self._hoist_from_block(loop.body, hc)
return hc return hc
...@@ -134,8 +136,12 @@ class HoistLoopInvariantDeclarations: ...@@ -134,8 +136,12 @@ class HoistLoopInvariantDeclarations:
case PsExpression(): case PsExpression():
return return
case PsDeclaration(PsSymbolExpr(lhs_symb), _):
hc.assigned_symbols.add(lhs_symb)
case PsAssignment(PsSymbolExpr(lhs_symb), _): case PsAssignment(PsSymbolExpr(lhs_symb), _):
hc.assigned_symbols.add(lhs_symb) hc.assigned_symbols.add(lhs_symb)
hc.mutated_symbols.add(lhs_symb)
case PsAssignment(_, _): case PsAssignment(_, _):
return return
...@@ -147,6 +153,7 @@ class HoistLoopInvariantDeclarations: ...@@ -147,6 +153,7 @@ class HoistLoopInvariantDeclarations:
loop = stmt loop = stmt
nested_hc = self._hoist(loop) nested_hc = self._hoist(loop)
hc.assigned_symbols |= nested_hc.assigned_symbols hc.assigned_symbols |= nested_hc.assigned_symbols
hc.mutated_symbols |= nested_hc.mutated_symbols
statements_new += nested_hc.hoisted_nodes statements_new += nested_hc.hoisted_nodes
if loop.body.statements: if loop.body.statements:
statements_new.append(loop) statements_new.append(loop)
...@@ -169,7 +176,8 @@ class HoistLoopInvariantDeclarations: ...@@ -169,7 +176,8 @@ class HoistLoopInvariantDeclarations:
for node in block.statements: for node in block.statements:
if isinstance(node, PsDeclaration): if isinstance(node, PsDeclaration):
if hc._is_invariant(node.rhs): lhs_symb = cast(PsSymbolExpr, node.lhs).symbol
if lhs_symb not in hc.mutated_symbols and hc._is_invariant(node.rhs):
hc.hoisted_nodes.append(node) hc.hoisted_nodes.append(node)
hc.invariant_symbols.add(node.declared_symbol) hc.invariant_symbols.add(node.declared_symbol)
else: else:
......
...@@ -193,3 +193,22 @@ def test_hoisting_eliminates_loops(): ...@@ -193,3 +193,22 @@ def test_hoisting_eliminates_loops():
assert isinstance(ast, PsBlock) assert isinstance(ast, PsBlock)
# All statements are hoisted and the loops are removed # All statements are hoisted and the loops are removed
assert ast.statements == invariant_decls assert ast.statements == invariant_decls
def test_hoist_mutation():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
hoist = HoistLoopInvariantDeclarations(ctx)
x = sp.Symbol("x")
x_decl = factory.parse_sympy(Assignment(x, 1))
x_update = factory.parse_sympy(AddAugmentedAssignment(x, 1))
inner_loop = factory.loop("j", slice(10), PsBlock([x_update]))
outer_loop = factory.loop("i", slice(10), PsBlock([x_decl, inner_loop]))
result = hoist(outer_loop)
# x is updated in the loop, so nothing can be hoisted
assert isinstance(result, PsLoop)
assert result.body.statements == [x_decl, inner_loop]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment