From cb4a449d92f75ec03e99ede8989f7966d14eea84 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 13 Dec 2023 15:03:55 +0100
Subject: [PATCH] added switch-case

---
 src/pystencilssfg/composer.py         | 66 ++++++++++++++++++++-------
 src/pystencilssfg/tree/conditional.py | 54 ++++++++++++++++++++--
 2 files changed, 99 insertions(+), 21 deletions(-)

diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py
index c46ba07..cb51c56 100644
--- a/src/pystencilssfg/composer.py
+++ b/src/pystencilssfg/composer.py
@@ -16,7 +16,7 @@ from .tree import (
     SfgBlock,
 )
 from .tree.deferred_nodes import SfgDeferredFieldMapping
-from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch
+from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch, SfgSwitch
 from .source_components import (
     SfgFunction,
     SfgHeaderInclude,
@@ -82,8 +82,8 @@ class SfgComposer:
 
         return kns
 
-    def include(self, header_file: str):
-        self._ctx.add_include(parse_include(header_file))
+    def include(self, header_file: str, private: bool = False):
+        self._ctx.add_include(parse_include(header_file, private))
 
     def numpy_struct(
         self, name: str, dtype: np.dtype, add_constructor: bool = True
@@ -154,7 +154,7 @@ class SfgComposer:
         """
         return SfgKernelCallNode(kernel_handle)
 
-    def seq(self, *args: SfgCallTreeNode) -> SfgSequence:
+    def seq(self, *args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence:
         """Syntax sequencing. For details, refer to [make_sequence][pystencilssfg.composer.make_sequence]"""
         return make_sequence(*args)
 
@@ -180,6 +180,9 @@ class SfgComposer:
         """
         return SfgBranchBuilder()
 
+    def switch(self, switch_arg: str | TypedSymbolOrObject) -> SfgSwitchBuilder:
+        return SfgSwitchBuilder(switch_arg)
+
     def map_field(self, field: Field, src_object: SrcField) -> SfgDeferredFieldMapping:
         """Map a pystencils field to a field data structure, from which pointers, sizes
         and strides should be extracted.
@@ -322,7 +325,37 @@ class SfgBranchBuilder(SfgNodeBuilder):
         return SfgBranch(self._cond, self._branch_true, self._branch_false)
 
 
-def parse_include(incl: str | SfgHeaderInclude):
+class SfgSwitchBuilder(SfgNodeBuilder):
+    def __init__(self, switch_arg: str | TypedSymbolOrObject):
+        self._switch_arg = switch_arg
+        self._cases: dict[str, SfgCallTreeNode] = dict()
+        self._default: SfgCallTreeNode | None = None
+
+    def case(self, label: str):
+        if label in self._cases:
+            raise SfgException(f"Duplicate case: {label}")
+
+        def sequencer(*args):
+            tree = make_sequence(*args)
+            self._cases[label] = tree
+            return self
+
+        return sequencer
+
+    def default(self, *args):
+        if self._default is not None:
+            raise SfgException("Duplicate default case")
+
+        tree = make_sequence(*args)
+        self._default = tree
+
+        return self
+
+    def resolve(self) -> SfgCallTreeNode:
+        return SfgSwitch(self._switch_arg, self._cases, self._default)
+
+
+def parse_include(incl: str | SfgHeaderInclude, private: bool = False):
     if isinstance(incl, SfgHeaderInclude):
         return incl
 
@@ -331,7 +364,7 @@ def parse_include(incl: str | SfgHeaderInclude):
         incl = incl[1:-1]
         system_header = True
 
-    return SfgHeaderInclude(incl, system_header=system_header)
+    return SfgHeaderInclude(incl, system_header=system_header, private=private)
 
 
 class SfgClassComposer:
@@ -347,10 +380,9 @@ class SfgClassComposer:
 
         def __call__(
             self,
-            *args: SfgClassMember
-            | SfgClassComposer.ConstructorBuilder
-            | SrcObject
-            | str,
+            *args: (
+                SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str
+            ),
         ):
             for arg in args:
                 self._vis_block.append_member(SfgClassComposer._resolve_member(arg))
@@ -430,11 +462,13 @@ class SfgClassComposer:
         self._ctx.add_class(cls)
 
         def sequencer(
-            *args: SfgClassComposer.VisibilityContext
-            | SfgClassMember
-            | SfgClassComposer.ConstructorBuilder
-            | SrcObject
-            | str,
+            *args: (
+                SfgClassComposer.VisibilityContext
+                | SfgClassMember
+                | SfgClassComposer.ConstructorBuilder
+                | SrcObject
+                | str
+            ),
         ):
             default_ended = False
 
@@ -465,7 +499,7 @@ class SfgClassComposer:
 
     @staticmethod
     def _resolve_member(
-        arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str,
+        arg: (SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str),
     ):
         if isinstance(arg, SrcObject):
             return SfgMemberVariable(arg.name, arg.dtype)
diff --git a/src/pystencilssfg/tree/conditional.py b/src/pystencilssfg/tree/conditional.py
index 4c9021a..4a35904 100644
--- a/src/pystencilssfg/tree/conditional.py
+++ b/src/pystencilssfg/tree/conditional.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Optional, cast, Generator
 
 from pystencils.typing import TypedSymbol, BasicType
 
@@ -60,10 +60,12 @@ class IntOdd(SfgCondition):
 
 
 class SfgBranch(SfgCallTreeNode):
-    def __init__(self,
-                 cond: SfgCondition,
-                 branch_true: SfgCallTreeNode,
-                 branch_false: Optional[SfgCallTreeNode] = None):
+    def __init__(
+        self,
+        cond: SfgCondition,
+        branch_true: SfgCallTreeNode,
+        branch_false: Optional[SfgCallTreeNode] = None,
+    ):
         super().__init__(cond, branch_true, *((branch_false,) if branch_false else ()))
 
     @property
@@ -89,3 +91,45 @@ class SfgBranch(SfgCallTreeNode):
             code += "\n}"
 
         return code
+
+
+class SfgSwitch(SfgCallTreeNode):
+    def __init__(
+        self,
+        switch_arg: str | TypedSymbolOrObject,
+        cases_dict: dict[str, SfgCallTreeNode],
+        default: SfgCallTreeNode | None = None,
+    ):
+        children = tuple(cases_dict.values()) + (
+            (default,) if default is not None else ()
+        )
+        super().__init__(*children)
+        self._switch_arg = switch_arg
+        self._cases_dict = cases_dict
+        self._default = default
+
+    @property
+    def switch_arg(self) -> str | TypedSymbolOrObject:
+        return self._switch_arg
+
+    def cases(self) -> Generator[tuple[str, SfgCallTreeNode], None, None]:
+        yield from self._cases_dict.items()
+
+    @property
+    def default(self) -> SfgCallTreeNode | None:
+        return self._default
+
+    def get_code(self, ctx: SfgContext) -> str:
+        code = f"switch({self._switch_arg}) {{\n"
+        for label, subtree in self._cases_dict.items():
+            code += f"case {label}: {{\n"
+            code += ctx.codestyle.indent(subtree.get_code(ctx))
+            code += "\nbreak;\n}\n"
+
+        if self._default is not None:
+            code += "default: {\n"
+            code += ctx.codestyle.indent(self._default.get_code(ctx))
+            code += "\nbreak;\n}\n"
+
+        code += "}"
+        return code
-- 
GitLab