From b6fac0d2449ff31432bbbe6e6c8c9b92cd6267df Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 18 Dec 2023 13:27:26 +0100
Subject: [PATCH] fixed SfgSwitchCase to be compliant with base class interface

---
 src/pystencilssfg/tree/basic_nodes.py    |  33 +++++++-
 src/pystencilssfg/tree/conditional.py    | 101 +++++++++++++++++++----
 src/pystencilssfg/visitors/dispatcher.py |   2 +-
 3 files changed, 114 insertions(+), 22 deletions(-)

diff --git a/src/pystencilssfg/tree/basic_nodes.py b/src/pystencilssfg/tree/basic_nodes.py
index 8561759..e209e0a 100644
--- a/src/pystencilssfg/tree/basic_nodes.py
+++ b/src/pystencilssfg/tree/basic_nodes.py
@@ -14,7 +14,18 @@ if TYPE_CHECKING:
 class SfgCallTreeNode(ABC):
     """Base class for all nodes comprising SFG call trees.
 
-    Any instantiable call tree node must implement `get_code`.
+    ## Code Printing
+
+    For extensibility, code printing is implemented inside the call tree.
+    Therefore, every instantiable call tree node must implement the method `get_code`.
+    By convention, the string returned by `get_code` should not contain a trailing newline.
+
+    ## Branching Structure
+
+    The branching structure of the call tree is managed uniformly through the `children` interface
+    of SfgCallTreeNode. Each subclass must ensure that access to and modification of
+    the branching structure through the `children` property and the `child` and `set_child`
+    methods is possible, if necessary by overriding the property and methods.
     """
 
     def __init__(self, *children: SfgCallTreeNode):
@@ -22,22 +33,29 @@ class SfgCallTreeNode(ABC):
 
     @property
     def children(self) -> tuple[SfgCallTreeNode, ...]:
+        """This node's children"""
         return tuple(self._children)
 
     @children.setter
     def children(self, cs: Sequence[SfgCallTreeNode]) -> None:
+        """Replaces this node's children. By default, the number of child nodes must not change."""
         if len(cs) != len(self._children):
             raise ValueError("The number of child nodes must remain the same!")
         self._children = list(cs)
 
     def child(self, idx: int) -> SfgCallTreeNode:
+        """Gets the child at index idx."""
         return self._children[idx]
 
+    def set_child(self, idx: int, c: SfgCallTreeNode):
+        """Replaces the child at index idx."""
+        self._children[idx] = c
+
     def __getitem__(self, idx: int) -> SfgCallTreeNode:
-        return self._children[idx]
+        return self.child(idx)
 
     def __setitem__(self, idx: int, c: SfgCallTreeNode) -> None:
-        self._children[idx] = c
+        self.set_child(idx, c)
 
     @abstractmethod
     def get_code(self, ctx: SfgContext) -> str:
@@ -183,6 +201,15 @@ class SfgBlock(SfgCallTreeNode):
         return "{\n" + subtree_code + "\n}"
 
 
+# class SfgForLoop(SfgCallTreeNode):
+#     def __init__(self, control_line: SfgStatements, body: SfgCallTreeNode):
+#         super().__init__(control_line, body)
+
+#     @property
+#     def body(self) -> SfgStatements:
+#         return cast(SfgStatements)
+
+
 class SfgKernelCallNode(SfgCallTreeLeaf):
     def __init__(self, kernel_handle: SfgKernelHandle):
         super().__init__()
diff --git a/src/pystencilssfg/tree/conditional.py b/src/pystencilssfg/tree/conditional.py
index 4a35904..65f1f87 100644
--- a/src/pystencilssfg/tree/conditional.py
+++ b/src/pystencilssfg/tree/conditional.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, Optional, cast, Generator
+from typing import TYPE_CHECKING, Optional, cast, Generator, Sequence, NewType
 
 from pystencils.typing import TypedSymbol, BasicType
 
@@ -93,6 +93,37 @@ class SfgBranch(SfgCallTreeNode):
         return code
 
 
+class SfgSwitchCase(SfgCallTreeNode):
+    DefaultCaseType = NewType("DefaultCaseType", object)
+    Default = DefaultCaseType(object())
+
+    def __init__(self, label: str | DefaultCaseType, body: SfgCallTreeNode):
+        self._label = label
+        super().__init__(body)
+
+    @property
+    def label(self) -> str | DefaultCaseType:
+        return self._label
+
+    @property
+    def body(self) -> SfgCallTreeNode:
+        return self._children[0]
+
+    @property
+    def is_default(self) -> bool:
+        return self._label == SfgSwitchCase.Default
+
+    def get_code(self, ctx: SfgContext) -> str:
+        code = ""
+        if self._label == SfgSwitchCase.Default:
+            code += "default: {\n"
+        else:
+            code += f"case {self._label}: {{\n"
+        code += ctx.codestyle.indent(self.body.get_code(ctx))
+        code += "\nbreak;\n}"
+        return code
+
+
 class SfgSwitch(SfgCallTreeNode):
     def __init__(
         self,
@@ -100,36 +131,70 @@ class SfgSwitch(SfgCallTreeNode):
         cases_dict: dict[str, SfgCallTreeNode],
         default: SfgCallTreeNode | None = None,
     ):
-        children = tuple(cases_dict.values()) + (
-            (default,) if default is not None else ()
-        )
-        super().__init__(*children)
+        children = [SfgSwitchCase(label, body) for label, body in cases_dict.items()]
+        if default is not None:
+            # invariant: the default case is always the last child
+            children += [SfgSwitchCase(SfgSwitchCase.Default, default)]
         self._switch_arg = switch_arg
-        self._cases_dict = cases_dict
         self._default = default
+        super().__init__(*children)
 
     @property
     def switch_arg(self) -> str | TypedSymbolOrObject:
         return self._switch_arg
 
-    def cases(self) -> Generator[tuple[str, SfgCallTreeNode], None, None]:
-        yield from self._cases_dict.items()
+    def cases(self) -> Generator[SfgCallTreeNode, None, None]:
+        if self._default is not None:
+            yield from self._children[:-1]
+        else:
+            yield from self._children
 
     @property
     def default(self) -> SfgCallTreeNode | None:
         return self._default
 
+    @property
+    def children(self) -> tuple[SfgCallTreeNode, ...]:
+        return tuple(self._children)
+
+    @children.setter
+    def children(self, cs: Sequence[SfgCallTreeNode]) -> None:
+        if len(cs) != len(self._children):
+            raise ValueError("The number of child nodes must remain the same!")
+
+        self._default = None
+        for i, c in enumerate(cs):
+            if not isinstance(c, SfgSwitchCase):
+                raise ValueError(
+                    "An SfgSwitch node can only have SfgSwitchCases as children."
+                )
+            if c.is_default:
+                if i != len(cs) - 1:
+                    raise ValueError("Default case must be listed last.")
+                else:
+                    self._default = c
+
+        self._children = list(cs)
+
+    def set_child(self, idx: int, c: SfgCallTreeNode):
+        if not isinstance(c, SfgSwitchCase):
+            raise ValueError(
+                "An SfgSwitch node can only have SfgSwitchCases as children."
+            )
+
+        if c.is_default:
+            if idx != len(self._children) - 1:
+                raise ValueError("Default case must be the last child.")
+            elif self._default is None:
+                raise ValueError("Cannot replace normal case with default case.")
+            else:
+                self._default = c
+                self._children[-1] = c
+        else:
+            self._children[idx] = c
+
     def get_code(self, ctx: SfgContext) -> str:
         code = f"switch({self._switch_arg}) {{\n"
-        for label, subtree in self._cases_dict.items():
-            code += f"case {label}: {{\n"
-            code += ctx.codestyle.indent(subtree.get_code(ctx))
-            code += "\nbreak;\n}\n"
-
-        if self._default is not None:
-            code += "default: {\n"
-            code += ctx.codestyle.indent(self._default.get_code(ctx))
-            code += "\nbreak;\n}\n"
-
+        code += "\n".join(c.get_code(ctx) for c in self.children)
         code += "}"
         return code
diff --git a/src/pystencilssfg/visitors/dispatcher.py b/src/pystencilssfg/visitors/dispatcher.py
index f0bc005..85a0f08 100644
--- a/src/pystencilssfg/visitors/dispatcher.py
+++ b/src/pystencilssfg/visitors/dispatcher.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Callable, TypeVar, Generic, ParamSpec
+from typing import Callable, TypeVar, Generic
 from types import MethodType
 
 from functools import wraps
-- 
GitLab