Skip to content
Snippets Groups Projects

Do not reorder accesses in `move_constants_before_loop` (quickly)

Merged Daniel Bauer requested to merge terraneo/pystencils:bauerd/move-constants-2 into master
1 file
+ 52
6
Compare changes
  • Side-by-side
  • Inline
@@ -4,9 +4,11 @@ import warnings
from collections import OrderedDict
from copy import deepcopy
from types import MappingProxyType
from typing import Set
import sympy as sp
import pystencils as ps
import pystencils.astnodes as ast
from pystencils.assignment import Assignment
from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type,
@@ -582,21 +584,65 @@ def move_constants_before_loop(ast_node):
"""
assert isinstance(node.parent, ast.Block)
def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool:
if isinstance(node, (ps.Assignment, ast.SympyAssignment)):
if isinstance(node.lhs, ast.ResolvedFieldAccess):
return node.lhs.typed_symbol.name in symbol_names
else:
return node.lhs.name in symbol_names
elif isinstance(node, ast.Block):
for arg in node.args:
if isinstance(arg, ast.SympyAssignment) and arg.is_declaration:
continue
if modifies_or_declares(arg, symbol_names):
return True
return False
elif isinstance(node, ast.LoopOverCoordinate):
return modifies_or_declares(node.body, symbol_names)
elif isinstance(node, ast.Conditional):
return (
modifies_or_declares(node.true_block, symbol_names)
or (node.false_block and modifies_or_declares(node.false_block, symbol_names))
)
elif isinstance(node, ast.KernelFunction):
return False
else:
raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
f'The expression {node} of type {type(node)} is not known yet.')
dependencies = {s.name for s in node.undefined_symbols}
last_block = node.parent
last_block_child = node
element = node.parent
prev_element = node
while element:
if isinstance(element, ast.Block):
if isinstance(element, (ast.Conditional, ast.KernelFunction)):
# Never move out of Conditionals or KernelFunctions.
break
elif isinstance(element, ast.Block):
last_block = element
last_block_child = prev_element
if isinstance(element, ast.Conditional):
break
if any(modifies_or_declares(sibling, dependencies) for sibling in element.args):
# The node depends on one of the statements in this block.
# Do not move further out.
break
elif isinstance(element, ast.LoopOverCoordinate):
if element.loop_counter_symbol.name in dependencies:
# The node depends on the loop counter.
# Do not move out of this loop.
break
else:
critical_symbols = set([s.name for s in element.symbols_defined])
if set([s.name for s in node.undefined_symbols]).intersection(critical_symbols):
break
raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
f'The expression {element} of type {type(element)} is not known yet.')
# No dependencies to symbols defined/modified within the current element.
# We can move the node up one level and in front of the current element.
prev_element = element
element = element.parent
return last_block, last_block_child
Loading