Skip to content
Snippets Groups Projects
Commit 88170aa4 authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon:
Browse files

do not move accesses before last modification

parent 994f6fcb
No related branches found
No related tags found
No related merge requests found
...@@ -4,9 +4,11 @@ import warnings ...@@ -4,9 +4,11 @@ import warnings
from collections import OrderedDict from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
from types import MappingProxyType from types import MappingProxyType
from typing import Set
import sympy as sp import sympy as sp
import pystencils as ps
import pystencils.astnodes as ast import pystencils.astnodes as ast
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type, from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type,
...@@ -582,21 +584,65 @@ def move_constants_before_loop(ast_node): ...@@ -582,21 +584,65 @@ def move_constants_before_loop(ast_node):
""" """
assert isinstance(node.parent, ast.Block) assert isinstance(node.parent, ast.Block)
def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool:
if isinstance(node, (ps.Assignment, ast.SympyAssignment)):
if isinstance(node.lhs, ast.ResolvedFieldAccess):
return node.lhs.typed_symbol.name in symbol_names
else:
return node.lhs.name in symbol_names
elif isinstance(node, ast.Block):
for arg in node.args:
if isinstance(arg, ast.SympyAssignment) and arg.is_declaration:
continue
if modifies_or_declares(arg, symbol_names):
return True
return False
elif isinstance(node, ast.LoopOverCoordinate):
return modifies_or_declares(node.body, symbol_names)
elif isinstance(node, ast.Conditional):
return (
modifies_or_declares(node.true_block, symbol_names)
or (node.false_block and modifies_or_declares(node.false_block, symbol_names))
)
elif isinstance(node, ast.KernelFunction):
return False
else:
raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
f'The expression {node} of type {type(node)} is not known yet.')
dependencies = {s.name for s in node.undefined_symbols}
last_block = node.parent last_block = node.parent
last_block_child = node last_block_child = node
element = node.parent element = node.parent
prev_element = node prev_element = node
while element: while element:
if isinstance(element, ast.Block): if isinstance(element, (ast.Conditional, ast.KernelFunction)):
# Never move out of Conditionals or KernelFunctions.
break
elif isinstance(element, ast.Block):
last_block = element last_block = element
last_block_child = prev_element last_block_child = prev_element
if isinstance(element, ast.Conditional): if any(modifies_or_declares(sibling, dependencies) for sibling in element.args):
break # The node depends on one of the statements in this block.
# Do not move further out.
break
elif isinstance(element, ast.LoopOverCoordinate):
if element.loop_counter_symbol.name in dependencies:
# The node depends on the loop counter.
# Do not move out of this loop.
break
else: else:
critical_symbols = set([s.name for s in element.symbols_defined]) raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
if set([s.name for s in node.undefined_symbols]).intersection(critical_symbols): f'The expression {element} of type {type(element)} is not known yet.')
break
# No dependencies to symbols defined/modified within the current element.
# We can move the node up one level and in front of the current element.
prev_element = element prev_element = element
element = element.parent element = element.parent
return last_block, last_block_child return last_block, last_block_child
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment