diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 02806d62223c8443cc2bb91740afd2c0ef6bda55..4e3d3862c874df879ec5c021e429cc14709f75e9 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -4,9 +4,11 @@ import warnings from collections import OrderedDict from copy import deepcopy from types import MappingProxyType +from typing import Set import sympy as sp +import pystencils as ps import pystencils.astnodes as ast from pystencils.assignment import Assignment from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type, @@ -582,21 +584,65 @@ def move_constants_before_loop(ast_node): """ assert isinstance(node.parent, ast.Block) + def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool: + if isinstance(node, (ps.Assignment, ast.SympyAssignment)): + if isinstance(node.lhs, ast.ResolvedFieldAccess): + return node.lhs.typed_symbol.name in symbol_names + else: + return node.lhs.name in symbol_names + elif isinstance(node, ast.Block): + for arg in node.args: + if isinstance(arg, ast.SympyAssignment) and arg.is_declaration: + continue + if modifies_or_declares(arg, symbol_names): + return True + return False + elif isinstance(node, ast.LoopOverCoordinate): + return modifies_or_declares(node.body, symbol_names) + elif isinstance(node, ast.Conditional): + return ( + modifies_or_declares(node.true_block, symbol_names) + or (node.false_block and modifies_or_declares(node.false_block, symbol_names)) + ) + elif isinstance(node, ast.KernelFunction): + return False + else: + raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n' + f'The expression {node} of type {type(node)} is not known yet.') + + dependencies = {s.name for s in node.undefined_symbols} + last_block = node.parent last_block_child = node element = node.parent prev_element = node + while element: - if isinstance(element, ast.Block): + if isinstance(element, (ast.Conditional, ast.KernelFunction)): + # Never move out of Conditionals or KernelFunctions. + break + + elif isinstance(element, ast.Block): last_block = element last_block_child = prev_element - if isinstance(element, ast.Conditional): - break + if any(modifies_or_declares(sibling, dependencies) for sibling in element.args): + # The node depends on one of the statements in this block. + # Do not move further out. + break + + elif isinstance(element, ast.LoopOverCoordinate): + if element.loop_counter_symbol.name in dependencies: + # The node depends on the loop counter. + # Do not move out of this loop. + break + else: - critical_symbols = set([s.name for s in element.symbols_defined]) - if set([s.name for s in node.undefined_symbols]).intersection(critical_symbols): - break + raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n' + f'The expression {element} of type {type(element)} is not known yet.') + + # No dependencies to symbols defined/modified within the current element. + # We can move the node up one level and in front of the current element. prev_element = element element = element.parent return last_block, last_block_child