From cb4a449d92f75ec03e99ede8989f7966d14eea84 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 13 Dec 2023 15:03:55 +0100 Subject: [PATCH] added switch-case --- src/pystencilssfg/composer.py | 66 ++++++++++++++++++++------- src/pystencilssfg/tree/conditional.py | 54 ++++++++++++++++++++-- 2 files changed, 99 insertions(+), 21 deletions(-) diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py index c46ba07..cb51c56 100644 --- a/src/pystencilssfg/composer.py +++ b/src/pystencilssfg/composer.py @@ -16,7 +16,7 @@ from .tree import ( SfgBlock, ) from .tree.deferred_nodes import SfgDeferredFieldMapping -from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch +from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch, SfgSwitch from .source_components import ( SfgFunction, SfgHeaderInclude, @@ -82,8 +82,8 @@ class SfgComposer: return kns - def include(self, header_file: str): - self._ctx.add_include(parse_include(header_file)) + def include(self, header_file: str, private: bool = False): + self._ctx.add_include(parse_include(header_file, private)) def numpy_struct( self, name: str, dtype: np.dtype, add_constructor: bool = True @@ -154,7 +154,7 @@ class SfgComposer: """ return SfgKernelCallNode(kernel_handle) - def seq(self, *args: SfgCallTreeNode) -> SfgSequence: + def seq(self, *args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence: """Syntax sequencing. For details, refer to [make_sequence][pystencilssfg.composer.make_sequence]""" return make_sequence(*args) @@ -180,6 +180,9 @@ class SfgComposer: """ return SfgBranchBuilder() + def switch(self, switch_arg: str | TypedSymbolOrObject) -> SfgSwitchBuilder: + return SfgSwitchBuilder(switch_arg) + def map_field(self, field: Field, src_object: SrcField) -> SfgDeferredFieldMapping: """Map a pystencils field to a field data structure, from which pointers, sizes and strides should be extracted. @@ -322,7 +325,37 @@ class SfgBranchBuilder(SfgNodeBuilder): return SfgBranch(self._cond, self._branch_true, self._branch_false) -def parse_include(incl: str | SfgHeaderInclude): +class SfgSwitchBuilder(SfgNodeBuilder): + def __init__(self, switch_arg: str | TypedSymbolOrObject): + self._switch_arg = switch_arg + self._cases: dict[str, SfgCallTreeNode] = dict() + self._default: SfgCallTreeNode | None = None + + def case(self, label: str): + if label in self._cases: + raise SfgException(f"Duplicate case: {label}") + + def sequencer(*args): + tree = make_sequence(*args) + self._cases[label] = tree + return self + + return sequencer + + def default(self, *args): + if self._default is not None: + raise SfgException("Duplicate default case") + + tree = make_sequence(*args) + self._default = tree + + return self + + def resolve(self) -> SfgCallTreeNode: + return SfgSwitch(self._switch_arg, self._cases, self._default) + + +def parse_include(incl: str | SfgHeaderInclude, private: bool = False): if isinstance(incl, SfgHeaderInclude): return incl @@ -331,7 +364,7 @@ def parse_include(incl: str | SfgHeaderInclude): incl = incl[1:-1] system_header = True - return SfgHeaderInclude(incl, system_header=system_header) + return SfgHeaderInclude(incl, system_header=system_header, private=private) class SfgClassComposer: @@ -347,10 +380,9 @@ class SfgClassComposer: def __call__( self, - *args: SfgClassMember - | SfgClassComposer.ConstructorBuilder - | SrcObject - | str, + *args: ( + SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str + ), ): for arg in args: self._vis_block.append_member(SfgClassComposer._resolve_member(arg)) @@ -430,11 +462,13 @@ class SfgClassComposer: self._ctx.add_class(cls) def sequencer( - *args: SfgClassComposer.VisibilityContext - | SfgClassMember - | SfgClassComposer.ConstructorBuilder - | SrcObject - | str, + *args: ( + SfgClassComposer.VisibilityContext + | SfgClassMember + | SfgClassComposer.ConstructorBuilder + | SrcObject + | str + ), ): default_ended = False @@ -465,7 +499,7 @@ class SfgClassComposer: @staticmethod def _resolve_member( - arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str, + arg: (SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str), ): if isinstance(arg, SrcObject): return SfgMemberVariable(arg.name, arg.dtype) diff --git a/src/pystencilssfg/tree/conditional.py b/src/pystencilssfg/tree/conditional.py index 4c9021a..4a35904 100644 --- a/src/pystencilssfg/tree/conditional.py +++ b/src/pystencilssfg/tree/conditional.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Optional, cast, Generator from pystencils.typing import TypedSymbol, BasicType @@ -60,10 +60,12 @@ class IntOdd(SfgCondition): class SfgBranch(SfgCallTreeNode): - def __init__(self, - cond: SfgCondition, - branch_true: SfgCallTreeNode, - branch_false: Optional[SfgCallTreeNode] = None): + def __init__( + self, + cond: SfgCondition, + branch_true: SfgCallTreeNode, + branch_false: Optional[SfgCallTreeNode] = None, + ): super().__init__(cond, branch_true, *((branch_false,) if branch_false else ())) @property @@ -89,3 +91,45 @@ class SfgBranch(SfgCallTreeNode): code += "\n}" return code + + +class SfgSwitch(SfgCallTreeNode): + def __init__( + self, + switch_arg: str | TypedSymbolOrObject, + cases_dict: dict[str, SfgCallTreeNode], + default: SfgCallTreeNode | None = None, + ): + children = tuple(cases_dict.values()) + ( + (default,) if default is not None else () + ) + super().__init__(*children) + self._switch_arg = switch_arg + self._cases_dict = cases_dict + self._default = default + + @property + def switch_arg(self) -> str | TypedSymbolOrObject: + return self._switch_arg + + def cases(self) -> Generator[tuple[str, SfgCallTreeNode], None, None]: + yield from self._cases_dict.items() + + @property + def default(self) -> SfgCallTreeNode | None: + return self._default + + def get_code(self, ctx: SfgContext) -> str: + code = f"switch({self._switch_arg}) {{\n" + for label, subtree in self._cases_dict.items(): + code += f"case {label}: {{\n" + code += ctx.codestyle.indent(subtree.get_code(ctx)) + code += "\nbreak;\n}\n" + + if self._default is not None: + code += "default: {\n" + code += ctx.codestyle.indent(self._default.get_code(ctx)) + code += "\nbreak;\n}\n" + + code += "}" + return code -- GitLab