From b6fac0d2449ff31432bbbe6e6c8c9b92cd6267df Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 18 Dec 2023 13:27:26 +0100 Subject: [PATCH] fixed SfgSwitchCase to be compliant with base class interface --- src/pystencilssfg/tree/basic_nodes.py | 33 +++++++- src/pystencilssfg/tree/conditional.py | 101 +++++++++++++++++++---- src/pystencilssfg/visitors/dispatcher.py | 2 +- 3 files changed, 114 insertions(+), 22 deletions(-) diff --git a/src/pystencilssfg/tree/basic_nodes.py b/src/pystencilssfg/tree/basic_nodes.py index 8561759..e209e0a 100644 --- a/src/pystencilssfg/tree/basic_nodes.py +++ b/src/pystencilssfg/tree/basic_nodes.py @@ -14,7 +14,18 @@ if TYPE_CHECKING: class SfgCallTreeNode(ABC): """Base class for all nodes comprising SFG call trees. - Any instantiable call tree node must implement `get_code`. + ## Code Printing + + For extensibility, code printing is implemented inside the call tree. + Therefore, every instantiable call tree node must implement the method `get_code`. + By convention, the string returned by `get_code` should not contain a trailing newline. + + ## Branching Structure + + The branching structure of the call tree is managed uniformly through the `children` interface + of SfgCallTreeNode. Each subclass must ensure that access to and modification of + the branching structure through the `children` property and the `child` and `set_child` + methods is possible, if necessary by overriding the property and methods. """ def __init__(self, *children: SfgCallTreeNode): @@ -22,22 +33,29 @@ class SfgCallTreeNode(ABC): @property def children(self) -> tuple[SfgCallTreeNode, ...]: + """This node's children""" return tuple(self._children) @children.setter def children(self, cs: Sequence[SfgCallTreeNode]) -> None: + """Replaces this node's children. By default, the number of child nodes must not change.""" if len(cs) != len(self._children): raise ValueError("The number of child nodes must remain the same!") self._children = list(cs) def child(self, idx: int) -> SfgCallTreeNode: + """Gets the child at index idx.""" return self._children[idx] + def set_child(self, idx: int, c: SfgCallTreeNode): + """Replaces the child at index idx.""" + self._children[idx] = c + def __getitem__(self, idx: int) -> SfgCallTreeNode: - return self._children[idx] + return self.child(idx) def __setitem__(self, idx: int, c: SfgCallTreeNode) -> None: - self._children[idx] = c + self.set_child(idx, c) @abstractmethod def get_code(self, ctx: SfgContext) -> str: @@ -183,6 +201,15 @@ class SfgBlock(SfgCallTreeNode): return "{\n" + subtree_code + "\n}" +# class SfgForLoop(SfgCallTreeNode): +# def __init__(self, control_line: SfgStatements, body: SfgCallTreeNode): +# super().__init__(control_line, body) + +# @property +# def body(self) -> SfgStatements: +# return cast(SfgStatements) + + class SfgKernelCallNode(SfgCallTreeLeaf): def __init__(self, kernel_handle: SfgKernelHandle): super().__init__() diff --git a/src/pystencilssfg/tree/conditional.py b/src/pystencilssfg/tree/conditional.py index 4a35904..65f1f87 100644 --- a/src/pystencilssfg/tree/conditional.py +++ b/src/pystencilssfg/tree/conditional.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast, Generator +from typing import TYPE_CHECKING, Optional, cast, Generator, Sequence, NewType from pystencils.typing import TypedSymbol, BasicType @@ -93,6 +93,37 @@ class SfgBranch(SfgCallTreeNode): return code +class SfgSwitchCase(SfgCallTreeNode): + DefaultCaseType = NewType("DefaultCaseType", object) + Default = DefaultCaseType(object()) + + def __init__(self, label: str | DefaultCaseType, body: SfgCallTreeNode): + self._label = label + super().__init__(body) + + @property + def label(self) -> str | DefaultCaseType: + return self._label + + @property + def body(self) -> SfgCallTreeNode: + return self._children[0] + + @property + def is_default(self) -> bool: + return self._label == SfgSwitchCase.Default + + def get_code(self, ctx: SfgContext) -> str: + code = "" + if self._label == SfgSwitchCase.Default: + code += "default: {\n" + else: + code += f"case {self._label}: {{\n" + code += ctx.codestyle.indent(self.body.get_code(ctx)) + code += "\nbreak;\n}" + return code + + class SfgSwitch(SfgCallTreeNode): def __init__( self, @@ -100,36 +131,70 @@ class SfgSwitch(SfgCallTreeNode): cases_dict: dict[str, SfgCallTreeNode], default: SfgCallTreeNode | None = None, ): - children = tuple(cases_dict.values()) + ( - (default,) if default is not None else () - ) - super().__init__(*children) + children = [SfgSwitchCase(label, body) for label, body in cases_dict.items()] + if default is not None: + # invariant: the default case is always the last child + children += [SfgSwitchCase(SfgSwitchCase.Default, default)] self._switch_arg = switch_arg - self._cases_dict = cases_dict self._default = default + super().__init__(*children) @property def switch_arg(self) -> str | TypedSymbolOrObject: return self._switch_arg - def cases(self) -> Generator[tuple[str, SfgCallTreeNode], None, None]: - yield from self._cases_dict.items() + def cases(self) -> Generator[SfgCallTreeNode, None, None]: + if self._default is not None: + yield from self._children[:-1] + else: + yield from self._children @property def default(self) -> SfgCallTreeNode | None: return self._default + @property + def children(self) -> tuple[SfgCallTreeNode, ...]: + return tuple(self._children) + + @children.setter + def children(self, cs: Sequence[SfgCallTreeNode]) -> None: + if len(cs) != len(self._children): + raise ValueError("The number of child nodes must remain the same!") + + self._default = None + for i, c in enumerate(cs): + if not isinstance(c, SfgSwitchCase): + raise ValueError( + "An SfgSwitch node can only have SfgSwitchCases as children." + ) + if c.is_default: + if i != len(cs) - 1: + raise ValueError("Default case must be listed last.") + else: + self._default = c + + self._children = list(cs) + + def set_child(self, idx: int, c: SfgCallTreeNode): + if not isinstance(c, SfgSwitchCase): + raise ValueError( + "An SfgSwitch node can only have SfgSwitchCases as children." + ) + + if c.is_default: + if idx != len(self._children) - 1: + raise ValueError("Default case must be the last child.") + elif self._default is None: + raise ValueError("Cannot replace normal case with default case.") + else: + self._default = c + self._children[-1] = c + else: + self._children[idx] = c + def get_code(self, ctx: SfgContext) -> str: code = f"switch({self._switch_arg}) {{\n" - for label, subtree in self._cases_dict.items(): - code += f"case {label}: {{\n" - code += ctx.codestyle.indent(subtree.get_code(ctx)) - code += "\nbreak;\n}\n" - - if self._default is not None: - code += "default: {\n" - code += ctx.codestyle.indent(self._default.get_code(ctx)) - code += "\nbreak;\n}\n" - + code += "\n".join(c.get_code(ctx) for c in self.children) code += "}" return code diff --git a/src/pystencilssfg/visitors/dispatcher.py b/src/pystencilssfg/visitors/dispatcher.py index f0bc005..85a0f08 100644 --- a/src/pystencilssfg/visitors/dispatcher.py +++ b/src/pystencilssfg/visitors/dispatcher.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Callable, TypeVar, Generic, ParamSpec +from typing import Callable, TypeVar, Generic from types import MethodType from functools import wraps -- GitLab