Skip to content
Snippets Groups Projects
Commit 59f06397 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

reworked tree and visitors:

 - Introduce type-based dispatch
 - collect children in base class
parent 5635e7e9
Branches
Tags
No related merge requests found
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Sequence, Set from typing import TYPE_CHECKING, Sequence, Set, Tuple
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from itertools import chain from itertools import chain
...@@ -15,15 +15,27 @@ if TYPE_CHECKING: ...@@ -15,15 +15,27 @@ if TYPE_CHECKING:
class SfgCallTreeNode(ABC): class SfgCallTreeNode(ABC):
"""Base class for all nodes comprising SFG call trees. """ """Base class for all nodes comprising SFG call trees. """
def __init__(self, *children: SfgCallTreeNode):
self._children = children
@property @property
@abstractmethod def children(self) -> Tuple[SfgCallTreeNode]:
def children(self) -> Sequence[SfgCallTreeNode]: return self._children
pass
@abstractmethod def child(self, idx: int) -> SfgCallTreeNode:
def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: return self._children[idx]
pass
@children.setter
def children(self, cs: Sequence[SfgCallTreeNode]) -> None:
if len(cs) != len(self._children):
raise ValueError("The number of child nodes must remain the same!")
self._children = list(cs)
def __getitem__(self, idx: int) -> SfgCallTreeNode:
return self._children[idx]
def __setitem__(self, idx: int, c: SfgCallTreeNode) -> None:
self._children[idx] = c
@abstractmethod @abstractmethod
def get_code(self, ctx: SfgContext) -> str: def get_code(self, ctx: SfgContext) -> str:
...@@ -40,13 +52,6 @@ class SfgCallTreeNode(ABC): ...@@ -40,13 +52,6 @@ class SfgCallTreeNode(ABC):
class SfgCallTreeLeaf(SfgCallTreeNode, ABC): class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return ()
def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None:
raise SfgException("Leaf nodes have no children.")
@property @property
@abstractmethod @abstractmethod
def required_parameters(self) -> Set[TypedSymbolOrObject]: def required_parameters(self) -> Set[TypedSymbolOrObject]:
...@@ -74,6 +79,8 @@ class SfgStatements(SfgCallTreeLeaf): ...@@ -74,6 +79,8 @@ class SfgStatements(SfgCallTreeLeaf):
code_string: str, code_string: str,
defined_params: Sequence[TypedSymbolOrObject], defined_params: Sequence[TypedSymbolOrObject],
required_params: Sequence[TypedSymbolOrObject]): required_params: Sequence[TypedSymbolOrObject]):
super().__init__()
self._code_string = code_string self._code_string = code_string
self._defined_params = set(defined_params) self._defined_params = set(defined_params)
...@@ -102,14 +109,7 @@ class SfgStatements(SfgCallTreeLeaf): ...@@ -102,14 +109,7 @@ class SfgStatements(SfgCallTreeLeaf):
class SfgSequence(SfgCallTreeNode): class SfgSequence(SfgCallTreeNode):
def __init__(self, children: Sequence[SfgCallTreeNode]): def __init__(self, children: Sequence[SfgCallTreeNode]):
self._children = list(children) super().__init__(*children)
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return self._children
def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None:
self._children[child_idx] = node
def get_code(self, ctx: SfgContext) -> str: def get_code(self, ctx: SfgContext) -> str:
return "\n".join(c.get_code(ctx) for c in self._children) return "\n".join(c.get_code(ctx) for c in self._children)
...@@ -117,17 +117,11 @@ class SfgSequence(SfgCallTreeNode): ...@@ -117,17 +117,11 @@ class SfgSequence(SfgCallTreeNode):
class SfgBlock(SfgCallTreeNode): class SfgBlock(SfgCallTreeNode):
def __init__(self, subtree: SfgCallTreeNode): def __init__(self, subtree: SfgCallTreeNode):
super().__init__() super().__init__(subtree)
self._subtree = subtree
@property @property
def children(self) -> Sequence[SfgCallTreeNode]: def subtree(self) -> SfgCallTreeNode:
return [self._subtree] return self._children[0]
def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None:
match child_idx:
case 0: self._subtree = node
case _: raise IndexError(f"Invalid child index: {child_idx}. SfgBlock has only a single child.")
def get_code(self, ctx: SfgContext) -> str: def get_code(self, ctx: SfgContext) -> str:
subtree_code = ctx.codestyle.indent(self._subtree.get_code(ctx)) subtree_code = ctx.codestyle.indent(self._subtree.get_code(ctx))
...@@ -137,6 +131,7 @@ class SfgBlock(SfgCallTreeNode): ...@@ -137,6 +131,7 @@ class SfgBlock(SfgCallTreeNode):
class SfgKernelCallNode(SfgCallTreeLeaf): class SfgKernelCallNode(SfgCallTreeLeaf):
def __init__(self, kernel_handle: SfgKernelHandle): def __init__(self, kernel_handle: SfgKernelHandle):
super().__init__()
self._kernel_handle = kernel_handle self._kernel_handle = kernel_handle
@property @property
......
...@@ -14,8 +14,10 @@ class SfgCondition(SfgCallTreeLeaf): ...@@ -14,8 +14,10 @@ class SfgCondition(SfgCallTreeLeaf):
class SfgCustomCondition(SfgCondition): class SfgCustomCondition(SfgCondition):
def __init__(self, cond_text: str): def __init__(self, cond_text: str):
super().__init__()
self._cond_text = cond_text self._cond_text = cond_text
@property
def required_parameters(self) -> Set[TypedSymbolOrObject]: def required_parameters(self) -> Set[TypedSymbolOrObject]:
return set() return set()
...@@ -32,31 +34,28 @@ class SfgBranch(SfgCallTreeNode): ...@@ -32,31 +34,28 @@ class SfgBranch(SfgCallTreeNode):
cond: SfgCondition, cond: SfgCondition,
branch_true: SfgCallTreeNode, branch_true: SfgCallTreeNode,
branch_false: Optional[SfgCallTreeNode] = None): branch_false: Optional[SfgCallTreeNode] = None):
self._cond = cond super().__init__(cond, branch_true, *((branch_false,) if branch_false else ()))
self._branch_true = branch_true
self._branch_false = branch_false @property
def condition(self) -> SfgCondition:
return self._children[0]
@property
def branch_true(self) -> SfgCallTreeNode:
return self._children[1]
@property @property
def children(self) -> Sequence[SfgCallTreeNode]: def branch_false(self) -> SfgCallTreeNode:
if self._branch_false is not None: return self._children[2]
return (self._branch_true, self._branch_false)
else:
return (self._branch_true,)
def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None:
match child_idx:
case 0: self._branch_true = node
case 1: self._branch_false = node
case _: raise IndexError(f"Invalid child index: {child_idx}. SfgBlock has only two children.")
def get_code(self, ctx: SfgContext) -> str: def get_code(self, ctx: SfgContext) -> str:
code = f"if({self._cond.get_code(ctx)}) {{\n" code = f"if({self.condition.get_code(ctx)}) {{\n"
code += ctx.codestyle.indent(self._branch_true.get_code(ctx)) code += ctx.codestyle.indent(self.branch_true.get_code(ctx))
code += "\n}" code += "\n}"
if self._branch_false is not None: if self.branch_false is not None:
code += "else {\n" code += "else {\n"
code += ctx.codestyle.indent(self._branch_false.get_code(ctx)) code += ctx.codestyle.indent(self.branch_false.get_code(ctx))
code += "\n}" code += "\n}"
return code return code
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Sequence, Set from typing import TYPE_CHECKING, Set
if TYPE_CHECKING: if TYPE_CHECKING:
from ..context import SfgContext from ..context import SfgContext
...@@ -25,15 +25,14 @@ class SfgDeferredNode(SfgCallTreeNode, ABC): ...@@ -25,15 +25,14 @@ class SfgDeferredNode(SfgCallTreeNode, ABC):
because information required for their construction is not yet known. because information required for their construction is not yet known.
""" """
@property class InvalidAccess:
def children(self) -> Sequence[SfgCallTreeNode]: def __get__(self):
raise SfgException("Deferred nodes cannot be descended into; expand it first.") raise SfgException("Invalid access into deferred node; deferred nodes must be expanded first.")
def replace_child(self, child_idx: int, node: SfgCallTreeNode) -> None: def __init__(self):
raise SfgException("Deferred nodes do not have children.") self._children = SfgDeferredNode.InvalidAccess
def get_code(self, ctx: SfgContext) -> str: get_code = InvalidAccess
raise SfgException("Deferred nodes can not generate code; they need to be expanded first.")
@abstractmethod @abstractmethod
def expand(self, ctx: SfgContext, *args, **kwargs) -> SfgCallTreeNode: def expand(self, ctx: SfgContext, *args, **kwargs) -> SfgCallTreeNode:
......
from __future__ import annotations
from typing import Callable
from types import MethodType
from functools import wraps
from .basic_nodes import SfgCallTreeNode
class VisitorDispatcher:
def __init__(self, wrapped_method):
self._dispatch_dict = {}
self._wrapped_method = wrapped_method
def case(self, node_type: type):
"""Decorator for visitor's methods"""
def decorate(handler: Callable):
if node_type in self._dispatch_dict:
raise ValueError(f"Duplicate visitor case {node_type}")
self._dispatch_dict[node_type] = handler
return handler
return decorate
def __call__(self, instance, node: SfgCallTreeNode, *args, **kwargs):
for cls in node.__class__.mro():
if cls in self._dispatch_dict:
return self._dispatch_dict[cls](instance, node, *args, **kwargs)
return self._wrapped_method(instance, node, *args, **kwargs)
def __get__(self, obj, objtype=None):
if obj is None:
return self
return MethodType(self, obj)
def visitor(method):
return wraps(method)(VisitorDispatcher(method))
...@@ -8,7 +8,7 @@ from pystencils.typing import TypedSymbol ...@@ -8,7 +8,7 @@ from pystencils.typing import TypedSymbol
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements
from .deferred_nodes import SfgParamCollectionDeferredNode from .deferred_nodes import SfgParamCollectionDeferredNode
from .dispatcher import visitor
if TYPE_CHECKING: if TYPE_CHECKING:
from ..context import SfgContext from ..context import SfgContext
...@@ -17,14 +17,13 @@ if TYPE_CHECKING: ...@@ -17,14 +17,13 @@ if TYPE_CHECKING:
class FlattenSequences(): class FlattenSequences():
"""Flattens any nested sequences occuring in a kernel call tree.""" """Flattens any nested sequences occuring in a kernel call tree."""
@visitor
def visit(self, node: SfgCallTreeNode) -> None: def visit(self, node: SfgCallTreeNode) -> None:
if isinstance(node, SfgSequence): for c in node.children:
return self._visit_SfgSequence(node) self.visit(c)
else:
for c in node.children:
self.visit(c)
def _visit_SfgSequence(self, sequence: SfgSequence) -> None: @visit.case(SfgSequence)
def sequence(self, sequence: SfgSequence) -> None:
children_flattened = [] children_flattened = []
def flatten(seq: SfgSequence): def flatten(seq: SfgSequence):
...@@ -61,18 +60,16 @@ class ExpandingParameterCollector(): ...@@ -61,18 +60,16 @@ class ExpandingParameterCollector():
self._flattener = FlattenSequences() self._flattener = FlattenSequences()
@visitor
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]: def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
if isinstance(node, SfgCallTreeLeaf): return self.branching_node(node)
return self._visit_SfgCallTreeLeaf(node)
elif isinstance(node, SfgSequence): @visit.case(SfgCallTreeLeaf)
return self._visit_SfgSequence(node) def leaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
else:
return self._visit_branchingNode(node)
def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
return leaf.required_parameters return leaf.required_parameters
def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]: @visit.case(SfgSequence)
def sequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
""" """
Only in a sequence may parameters be defined and visible to subsequent nodes. Only in a sequence may parameters be defined and visible to subsequent nodes.
""" """
...@@ -99,7 +96,7 @@ class ExpandingParameterCollector(): ...@@ -99,7 +96,7 @@ class ExpandingParameterCollector():
return params return params
def _visit_branchingNode(self, node: SfgCallTreeNode): def branching_node(self, node: SfgCallTreeNode):
""" """
Each interior node that is not a sequence simply requires the union of all parameters Each interior node that is not a sequence simply requires the union of all parameters
required by its children. required by its children.
...@@ -113,18 +110,16 @@ class ParameterCollector(): ...@@ -113,18 +110,16 @@ class ParameterCollector():
Requires that all sequences in the tree are flattened. Requires that all sequences in the tree are flattened.
""" """
@visitor
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]: def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
if isinstance(node, SfgCallTreeLeaf): return self.branching_node(node)
return self._visit_SfgCallTreeLeaf(node)
elif isinstance(node, SfgSequence): @visit.case(SfgCallTreeLeaf)
return self._visit_SfgSequence(node) def leaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
else:
return self._visit_branchingNode(node)
def _visit_SfgCallTreeLeaf(self, leaf: SfgCallTreeLeaf) -> Set[TypedSymbol]:
return leaf.required_parameters return leaf.required_parameters
def _visit_SfgSequence(self, sequence: SfgSequence) -> Set[TypedSymbol]: @visit.case(SfgSequence)
def sequence(self, sequence: SfgSequence) -> Set[TypedSymbol]:
""" """
Only in a sequence may parameters be defined and visible to subsequent nodes. Only in a sequence may parameters be defined and visible to subsequent nodes.
""" """
...@@ -138,7 +133,7 @@ class ParameterCollector(): ...@@ -138,7 +133,7 @@ class ParameterCollector():
params |= self.visit(c) params |= self.visit(c)
return params return params
def _visit_branchingNode(self, node: SfgCallTreeNode): def branching_node(self, node: SfgCallTreeNode):
""" """
Each interior node that is not a sequence simply requires the union of all parameters Each interior node that is not a sequence simply requires the union of all parameters
required by its children. required by its children.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment