Skip to content
Snippets Groups Projects

Do not hoist declarations of mutated variables

Merged Daniel Bauer requested to merge hyteg/pystencils:bauerd/fix-hoisting into backend-rework
Compare and
2 files
+ 28
1
Preferences
Compare changes
Files
2
@@ -27,6 +27,7 @@ class HoistContext:
def __init__(self) -> None:
self.hoisted_nodes: list[PsDeclaration] = []
self.assigned_symbols: set[PsSymbol] = set()
self.mutated_symbols: set[PsSymbol] = set()
self.invariant_symbols: set[PsSymbol] = set()
def _is_invariant(self, expr: PsExpression) -> bool:
@@ -123,6 +124,7 @@ class HoistLoopInvariantDeclarations:
"""Hoist invariant declarations out of the given loop."""
hc = HoistContext()
hc.assigned_symbols.add(loop.counter.symbol)
hc.mutated_symbols.add(loop.counter.symbol)
self._prepare_hoist(loop.body, hc)
self._hoist_from_block(loop.body, hc)
return hc
@@ -134,8 +136,12 @@ class HoistLoopInvariantDeclarations:
case PsExpression():
return
case PsDeclaration(PsSymbolExpr(lhs_symb), _):
hc.assigned_symbols.add(lhs_symb)
case PsAssignment(PsSymbolExpr(lhs_symb), _):
hc.assigned_symbols.add(lhs_symb)
hc.mutated_symbols.add(lhs_symb)
case PsAssignment(_, _):
return
@@ -147,6 +153,7 @@ class HoistLoopInvariantDeclarations:
loop = stmt
nested_hc = self._hoist(loop)
hc.assigned_symbols |= nested_hc.assigned_symbols
hc.mutated_symbols |= nested_hc.mutated_symbols
statements_new += nested_hc.hoisted_nodes
if loop.body.statements:
statements_new.append(loop)
@@ -169,7 +176,8 @@ class HoistLoopInvariantDeclarations:
for node in block.statements:
if isinstance(node, PsDeclaration):
if hc._is_invariant(node.rhs):
lhs_symb = cast(PsSymbolExpr, node.lhs).symbol
if lhs_symb not in hc.mutated_symbols and hc._is_invariant(node.rhs):
hc.hoisted_nodes.append(node)
hc.invariant_symbols.add(node.declared_symbol)
else: