Skip to content
Snippets Groups Projects
Commit cb4a449d authored by Frederik Hennig's avatar Frederik Hennig
Browse files

added switch-case

parent ed11c419
Branches
Tags
No related merge requests found
Pipeline #58322 passed
...@@ -16,7 +16,7 @@ from .tree import ( ...@@ -16,7 +16,7 @@ from .tree import (
SfgBlock, SfgBlock,
) )
from .tree.deferred_nodes import SfgDeferredFieldMapping from .tree.deferred_nodes import SfgDeferredFieldMapping
from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch, SfgSwitch
from .source_components import ( from .source_components import (
SfgFunction, SfgFunction,
SfgHeaderInclude, SfgHeaderInclude,
...@@ -82,8 +82,8 @@ class SfgComposer: ...@@ -82,8 +82,8 @@ class SfgComposer:
return kns return kns
def include(self, header_file: str): def include(self, header_file: str, private: bool = False):
self._ctx.add_include(parse_include(header_file)) self._ctx.add_include(parse_include(header_file, private))
def numpy_struct( def numpy_struct(
self, name: str, dtype: np.dtype, add_constructor: bool = True self, name: str, dtype: np.dtype, add_constructor: bool = True
...@@ -154,7 +154,7 @@ class SfgComposer: ...@@ -154,7 +154,7 @@ class SfgComposer:
""" """
return SfgKernelCallNode(kernel_handle) return SfgKernelCallNode(kernel_handle)
def seq(self, *args: SfgCallTreeNode) -> SfgSequence: def seq(self, *args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence:
"""Syntax sequencing. For details, refer to [make_sequence][pystencilssfg.composer.make_sequence]""" """Syntax sequencing. For details, refer to [make_sequence][pystencilssfg.composer.make_sequence]"""
return make_sequence(*args) return make_sequence(*args)
...@@ -180,6 +180,9 @@ class SfgComposer: ...@@ -180,6 +180,9 @@ class SfgComposer:
""" """
return SfgBranchBuilder() return SfgBranchBuilder()
def switch(self, switch_arg: str | TypedSymbolOrObject) -> SfgSwitchBuilder:
return SfgSwitchBuilder(switch_arg)
def map_field(self, field: Field, src_object: SrcField) -> SfgDeferredFieldMapping: def map_field(self, field: Field, src_object: SrcField) -> SfgDeferredFieldMapping:
"""Map a pystencils field to a field data structure, from which pointers, sizes """Map a pystencils field to a field data structure, from which pointers, sizes
and strides should be extracted. and strides should be extracted.
...@@ -322,7 +325,37 @@ class SfgBranchBuilder(SfgNodeBuilder): ...@@ -322,7 +325,37 @@ class SfgBranchBuilder(SfgNodeBuilder):
return SfgBranch(self._cond, self._branch_true, self._branch_false) return SfgBranch(self._cond, self._branch_true, self._branch_false)
def parse_include(incl: str | SfgHeaderInclude): class SfgSwitchBuilder(SfgNodeBuilder):
def __init__(self, switch_arg: str | TypedSymbolOrObject):
self._switch_arg = switch_arg
self._cases: dict[str, SfgCallTreeNode] = dict()
self._default: SfgCallTreeNode | None = None
def case(self, label: str):
if label in self._cases:
raise SfgException(f"Duplicate case: {label}")
def sequencer(*args):
tree = make_sequence(*args)
self._cases[label] = tree
return self
return sequencer
def default(self, *args):
if self._default is not None:
raise SfgException("Duplicate default case")
tree = make_sequence(*args)
self._default = tree
return self
def resolve(self) -> SfgCallTreeNode:
return SfgSwitch(self._switch_arg, self._cases, self._default)
def parse_include(incl: str | SfgHeaderInclude, private: bool = False):
if isinstance(incl, SfgHeaderInclude): if isinstance(incl, SfgHeaderInclude):
return incl return incl
...@@ -331,7 +364,7 @@ def parse_include(incl: str | SfgHeaderInclude): ...@@ -331,7 +364,7 @@ def parse_include(incl: str | SfgHeaderInclude):
incl = incl[1:-1] incl = incl[1:-1]
system_header = True system_header = True
return SfgHeaderInclude(incl, system_header=system_header) return SfgHeaderInclude(incl, system_header=system_header, private=private)
class SfgClassComposer: class SfgClassComposer:
...@@ -347,10 +380,9 @@ class SfgClassComposer: ...@@ -347,10 +380,9 @@ class SfgClassComposer:
def __call__( def __call__(
self, self,
*args: SfgClassMember *args: (
| SfgClassComposer.ConstructorBuilder SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str
| SrcObject ),
| str,
): ):
for arg in args: for arg in args:
self._vis_block.append_member(SfgClassComposer._resolve_member(arg)) self._vis_block.append_member(SfgClassComposer._resolve_member(arg))
...@@ -430,11 +462,13 @@ class SfgClassComposer: ...@@ -430,11 +462,13 @@ class SfgClassComposer:
self._ctx.add_class(cls) self._ctx.add_class(cls)
def sequencer( def sequencer(
*args: SfgClassComposer.VisibilityContext *args: (
| SfgClassMember SfgClassComposer.VisibilityContext
| SfgClassComposer.ConstructorBuilder | SfgClassMember
| SrcObject | SfgClassComposer.ConstructorBuilder
| str, | SrcObject
| str
),
): ):
default_ended = False default_ended = False
...@@ -465,7 +499,7 @@ class SfgClassComposer: ...@@ -465,7 +499,7 @@ class SfgClassComposer:
@staticmethod @staticmethod
def _resolve_member( def _resolve_member(
arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str, arg: (SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str),
): ):
if isinstance(arg, SrcObject): if isinstance(arg, SrcObject):
return SfgMemberVariable(arg.name, arg.dtype) return SfgMemberVariable(arg.name, arg.dtype)
......
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Optional, cast from typing import TYPE_CHECKING, Optional, cast, Generator
from pystencils.typing import TypedSymbol, BasicType from pystencils.typing import TypedSymbol, BasicType
...@@ -60,10 +60,12 @@ class IntOdd(SfgCondition): ...@@ -60,10 +60,12 @@ class IntOdd(SfgCondition):
class SfgBranch(SfgCallTreeNode): class SfgBranch(SfgCallTreeNode):
def __init__(self, def __init__(
cond: SfgCondition, self,
branch_true: SfgCallTreeNode, cond: SfgCondition,
branch_false: Optional[SfgCallTreeNode] = None): branch_true: SfgCallTreeNode,
branch_false: Optional[SfgCallTreeNode] = None,
):
super().__init__(cond, branch_true, *((branch_false,) if branch_false else ())) super().__init__(cond, branch_true, *((branch_false,) if branch_false else ()))
@property @property
...@@ -89,3 +91,45 @@ class SfgBranch(SfgCallTreeNode): ...@@ -89,3 +91,45 @@ class SfgBranch(SfgCallTreeNode):
code += "\n}" code += "\n}"
return code return code
class SfgSwitch(SfgCallTreeNode):
def __init__(
self,
switch_arg: str | TypedSymbolOrObject,
cases_dict: dict[str, SfgCallTreeNode],
default: SfgCallTreeNode | None = None,
):
children = tuple(cases_dict.values()) + (
(default,) if default is not None else ()
)
super().__init__(*children)
self._switch_arg = switch_arg
self._cases_dict = cases_dict
self._default = default
@property
def switch_arg(self) -> str | TypedSymbolOrObject:
return self._switch_arg
def cases(self) -> Generator[tuple[str, SfgCallTreeNode], None, None]:
yield from self._cases_dict.items()
@property
def default(self) -> SfgCallTreeNode | None:
return self._default
def get_code(self, ctx: SfgContext) -> str:
code = f"switch({self._switch_arg}) {{\n"
for label, subtree in self._cases_dict.items():
code += f"case {label}: {{\n"
code += ctx.codestyle.indent(subtree.get_code(ctx))
code += "\nbreak;\n}\n"
if self._default is not None:
code += "default: {\n"
code += ctx.codestyle.indent(self._default.get_code(ctx))
code += "\nbreak;\n}\n"
code += "}"
return code
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment