Skip to content
Snippets Groups Projects
Commit b6fac0d2 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fixed SfgSwitchCase to be compliant with base class interface

parent 8357464e
Branches
No related merge requests found
......@@ -14,7 +14,18 @@ if TYPE_CHECKING:
class SfgCallTreeNode(ABC):
"""Base class for all nodes comprising SFG call trees.
Any instantiable call tree node must implement `get_code`.
## Code Printing
For extensibility, code printing is implemented inside the call tree.
Therefore, every instantiable call tree node must implement the method `get_code`.
By convention, the string returned by `get_code` should not contain a trailing newline.
## Branching Structure
The branching structure of the call tree is managed uniformly through the `children` interface
of SfgCallTreeNode. Each subclass must ensure that access to and modification of
the branching structure through the `children` property and the `child` and `set_child`
methods is possible, if necessary by overriding the property and methods.
"""
def __init__(self, *children: SfgCallTreeNode):
......@@ -22,22 +33,29 @@ class SfgCallTreeNode(ABC):
@property
def children(self) -> tuple[SfgCallTreeNode, ...]:
"""This node's children"""
return tuple(self._children)
@children.setter
def children(self, cs: Sequence[SfgCallTreeNode]) -> None:
"""Replaces this node's children. By default, the number of child nodes must not change."""
if len(cs) != len(self._children):
raise ValueError("The number of child nodes must remain the same!")
self._children = list(cs)
def child(self, idx: int) -> SfgCallTreeNode:
"""Gets the child at index idx."""
return self._children[idx]
def set_child(self, idx: int, c: SfgCallTreeNode):
"""Replaces the child at index idx."""
self._children[idx] = c
def __getitem__(self, idx: int) -> SfgCallTreeNode:
return self._children[idx]
return self.child(idx)
def __setitem__(self, idx: int, c: SfgCallTreeNode) -> None:
self._children[idx] = c
self.set_child(idx, c)
@abstractmethod
def get_code(self, ctx: SfgContext) -> str:
......@@ -183,6 +201,15 @@ class SfgBlock(SfgCallTreeNode):
return "{\n" + subtree_code + "\n}"
# class SfgForLoop(SfgCallTreeNode):
# def __init__(self, control_line: SfgStatements, body: SfgCallTreeNode):
# super().__init__(control_line, body)
# @property
# def body(self) -> SfgStatements:
# return cast(SfgStatements)
class SfgKernelCallNode(SfgCallTreeLeaf):
def __init__(self, kernel_handle: SfgKernelHandle):
super().__init__()
......
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, cast, Generator
from typing import TYPE_CHECKING, Optional, cast, Generator, Sequence, NewType
from pystencils.typing import TypedSymbol, BasicType
......@@ -93,6 +93,37 @@ class SfgBranch(SfgCallTreeNode):
return code
class SfgSwitchCase(SfgCallTreeNode):
DefaultCaseType = NewType("DefaultCaseType", object)
Default = DefaultCaseType(object())
def __init__(self, label: str | DefaultCaseType, body: SfgCallTreeNode):
self._label = label
super().__init__(body)
@property
def label(self) -> str | DefaultCaseType:
return self._label
@property
def body(self) -> SfgCallTreeNode:
return self._children[0]
@property
def is_default(self) -> bool:
return self._label == SfgSwitchCase.Default
def get_code(self, ctx: SfgContext) -> str:
code = ""
if self._label == SfgSwitchCase.Default:
code += "default: {\n"
else:
code += f"case {self._label}: {{\n"
code += ctx.codestyle.indent(self.body.get_code(ctx))
code += "\nbreak;\n}"
return code
class SfgSwitch(SfgCallTreeNode):
def __init__(
self,
......@@ -100,36 +131,70 @@ class SfgSwitch(SfgCallTreeNode):
cases_dict: dict[str, SfgCallTreeNode],
default: SfgCallTreeNode | None = None,
):
children = tuple(cases_dict.values()) + (
(default,) if default is not None else ()
)
super().__init__(*children)
children = [SfgSwitchCase(label, body) for label, body in cases_dict.items()]
if default is not None:
# invariant: the default case is always the last child
children += [SfgSwitchCase(SfgSwitchCase.Default, default)]
self._switch_arg = switch_arg
self._cases_dict = cases_dict
self._default = default
super().__init__(*children)
@property
def switch_arg(self) -> str | TypedSymbolOrObject:
return self._switch_arg
def cases(self) -> Generator[tuple[str, SfgCallTreeNode], None, None]:
yield from self._cases_dict.items()
def cases(self) -> Generator[SfgCallTreeNode, None, None]:
if self._default is not None:
yield from self._children[:-1]
else:
yield from self._children
@property
def default(self) -> SfgCallTreeNode | None:
return self._default
@property
def children(self) -> tuple[SfgCallTreeNode, ...]:
return tuple(self._children)
@children.setter
def children(self, cs: Sequence[SfgCallTreeNode]) -> None:
if len(cs) != len(self._children):
raise ValueError("The number of child nodes must remain the same!")
self._default = None
for i, c in enumerate(cs):
if not isinstance(c, SfgSwitchCase):
raise ValueError(
"An SfgSwitch node can only have SfgSwitchCases as children."
)
if c.is_default:
if i != len(cs) - 1:
raise ValueError("Default case must be listed last.")
else:
self._default = c
self._children = list(cs)
def set_child(self, idx: int, c: SfgCallTreeNode):
if not isinstance(c, SfgSwitchCase):
raise ValueError(
"An SfgSwitch node can only have SfgSwitchCases as children."
)
if c.is_default:
if idx != len(self._children) - 1:
raise ValueError("Default case must be the last child.")
elif self._default is None:
raise ValueError("Cannot replace normal case with default case.")
else:
self._default = c
self._children[-1] = c
else:
self._children[idx] = c
def get_code(self, ctx: SfgContext) -> str:
code = f"switch({self._switch_arg}) {{\n"
for label, subtree in self._cases_dict.items():
code += f"case {label}: {{\n"
code += ctx.codestyle.indent(subtree.get_code(ctx))
code += "\nbreak;\n}\n"
if self._default is not None:
code += "default: {\n"
code += ctx.codestyle.indent(self._default.get_code(ctx))
code += "\nbreak;\n}\n"
code += "\n".join(c.get_code(ctx) for c in self.children)
code += "}"
return code
from __future__ import annotations
from typing import Callable, TypeVar, Generic, ParamSpec
from typing import Callable, TypeVar, Generic
from types import MethodType
from functools import wraps
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment