From be8dd46e126adad8c70f45d60527a5f9e30004bd Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 8 Mar 2024 16:26:01 +0100
Subject: [PATCH] extended switch/case support

---
 .flake8                                       |   1 +
 src/pystencilssfg/__init__.py                 |   7 +-
 src/pystencilssfg/__main__.py                 |   2 +-
 src/pystencilssfg/_version.py                 | 157 +++++++++++-------
 src/pystencilssfg/cli.py                      |  62 ++++---
 src/pystencilssfg/composer/__init__.py        |   7 +-
 src/pystencilssfg/composer/basic_composer.py  |  12 +-
 src/pystencilssfg/composer/class_composer.py  |   2 +-
 src/pystencilssfg/composer/custom.py          |   3 +-
 src/pystencilssfg/emission/__init__.py        |   4 +-
 src/pystencilssfg/emission/printers.py        |   2 +-
 src/pystencilssfg/exceptions.py               |   1 -
 src/pystencilssfg/source_components.py        |  23 ++-
 src/pystencilssfg/source_concepts/__init__.py |   4 +-
 .../source_concepts/cpp/std_mdspan.py         |  86 ++++++----
 src/pystencilssfg/tree/__init__.py            |   2 +-
 src/pystencilssfg/tree/basic_nodes.py         |   3 +-
 src/pystencilssfg/tree/deferred_nodes.py      |  20 ++-
 src/pystencilssfg/types.py                    |   2 +-
 src/pystencilssfg/visitors/tree_visitors.py   |  10 +-
 20 files changed, 252 insertions(+), 158 deletions(-)

diff --git a/.flake8 b/.flake8
index aa079ec..313ba09 100644
--- a/.flake8
+++ b/.flake8
@@ -1,2 +1,3 @@
 [flake8]
 max-line-length=120
+exclude = src/pystencilssfg/_version.py
diff --git a/src/pystencilssfg/__init__.py b/src/pystencilssfg/__init__.py
index 48de907..457c89c 100644
--- a/src/pystencilssfg/__init__.py
+++ b/src/pystencilssfg/__init__.py
@@ -3,9 +3,8 @@ from .generator import SourceFileGenerator
 from .composer import SfgComposer
 from .context import SfgContext
 
-__all__ = [
-    "SourceFileGenerator", "SfgComposer", "SfgConfiguration", "SfgContext"
-]
+__all__ = ["SourceFileGenerator", "SfgComposer", "SfgConfiguration", "SfgContext"]
 
 from . import _version
-__version__ = _version.get_versions()['version']
+
+__version__ = _version.get_versions()["version"]
diff --git a/src/pystencilssfg/__main__.py b/src/pystencilssfg/__main__.py
index ce41802..438005a 100644
--- a/src/pystencilssfg/__main__.py
+++ b/src/pystencilssfg/__main__.py
@@ -1,4 +1,4 @@
-
 if __name__ == "__main__":
     from .cli import cli_main
+
     cli_main("python -m pystencilssfg")
diff --git a/src/pystencilssfg/_version.py b/src/pystencilssfg/_version.py
index a5947d7..c9dee16 100644
--- a/src/pystencilssfg/_version.py
+++ b/src/pystencilssfg/_version.py
@@ -1,4 +1,3 @@
-
 # This file helps to compute a version number in source trees obtained from
 # git-archive tarball (such as those provided by githubs download-from-tag
 # feature). Distribution tarballs (built by setup.py sdist) and build
@@ -68,12 +67,14 @@ HANDLERS: Dict[str, Dict[str, Callable]] = {}
 
 def register_vcs_handler(vcs: str, method: str) -> Callable:  # decorator
     """Create decorator to mark a method as the handler of a VCS."""
+
     def decorate(f: Callable) -> Callable:
         """Store f in HANDLERS[vcs][method]."""
         if vcs not in HANDLERS:
             HANDLERS[vcs] = {}
         HANDLERS[vcs][method] = f
         return f
+
     return decorate
 
 
@@ -100,10 +101,14 @@ def run_command(
         try:
             dispcmd = str([command] + args)
             # remember shell=False, so use git.cmd on windows, not just git
-            process = subprocess.Popen([command] + args, cwd=cwd, env=env,
-                                       stdout=subprocess.PIPE,
-                                       stderr=(subprocess.PIPE if hide_stderr
-                                               else None), **popen_kwargs)
+            process = subprocess.Popen(
+                [command] + args,
+                cwd=cwd,
+                env=env,
+                stdout=subprocess.PIPE,
+                stderr=(subprocess.PIPE if hide_stderr else None),
+                **popen_kwargs,
+            )
             break
         except OSError as e:
             if e.errno == errno.ENOENT:
@@ -141,15 +146,21 @@ def versions_from_parentdir(
     for _ in range(3):
         dirname = os.path.basename(root)
         if dirname.startswith(parentdir_prefix):
-            return {"version": dirname[len(parentdir_prefix):],
-                    "full-revisionid": None,
-                    "dirty": False, "error": None, "date": None}
+            return {
+                "version": dirname[len(parentdir_prefix) :],
+                "full-revisionid": None,
+                "dirty": False,
+                "error": None,
+                "date": None,
+            }
         rootdirs.append(root)
         root = os.path.dirname(root)  # up a level
 
     if verbose:
-        print("Tried directories %s but none started with prefix %s" %
-              (str(rootdirs), parentdir_prefix))
+        print(
+            "Tried directories %s but none started with prefix %s"
+            % (str(rootdirs), parentdir_prefix)
+        )
     raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
 
 
@@ -212,7 +223,7 @@ def git_versions_from_keywords(
     # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
     # just "foo-1.0". If we see a "tag: " prefix, prefer those.
     TAG = "tag: "
-    tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
+    tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
     if not tags:
         # Either we're using git < 1.8.3, or there really are no tags. We use
         # a heuristic: assume all version tags have a digit. The old git %d
@@ -221,7 +232,7 @@ def git_versions_from_keywords(
         # between branches and tags. By ignoring refnames without digits, we
         # filter out many common branch names like "release" and
         # "stabilization", as well as "HEAD" and "master".
-        tags = {r for r in refs if re.search(r'\d', r)}
+        tags = {r for r in refs if re.search(r"\d", r)}
         if verbose:
             print("discarding '%s', no digits" % ",".join(refs - tags))
     if verbose:
@@ -229,32 +240,36 @@ def git_versions_from_keywords(
     for ref in sorted(tags):
         # sorting will prefer e.g. "2.0" over "2.0rc1"
         if ref.startswith(tag_prefix):
-            r = ref[len(tag_prefix):]
+            r = ref[len(tag_prefix) :]
             # Filter out refs that exactly match prefix or that don't start
             # with a number once the prefix is stripped (mostly a concern
             # when prefix is '')
-            if not re.match(r'\d', r):
+            if not re.match(r"\d", r):
                 continue
             if verbose:
                 print("picking %s" % r)
-            return {"version": r,
-                    "full-revisionid": keywords["full"].strip(),
-                    "dirty": False, "error": None,
-                    "date": date}
+            return {
+                "version": r,
+                "full-revisionid": keywords["full"].strip(),
+                "dirty": False,
+                "error": None,
+                "date": date,
+            }
     # no suitable tags, so version is "0+unknown", but full hex is still there
     if verbose:
         print("no suitable tags, using unknown + full revision id")
-    return {"version": "0+unknown",
-            "full-revisionid": keywords["full"].strip(),
-            "dirty": False, "error": "no suitable tags", "date": None}
+    return {
+        "version": "0+unknown",
+        "full-revisionid": keywords["full"].strip(),
+        "dirty": False,
+        "error": "no suitable tags",
+        "date": None,
+    }
 
 
 @register_vcs_handler("git", "pieces_from_vcs")
 def git_pieces_from_vcs(
-    tag_prefix: str,
-    root: str,
-    verbose: bool,
-    runner: Callable = run_command
+    tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command
 ) -> Dict[str, Any]:
     """Get version from 'git describe' in the root of the source tree.
 
@@ -273,8 +288,7 @@ def git_pieces_from_vcs(
     env.pop("GIT_DIR", None)
     runner = functools.partial(runner, env=env)
 
-    _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
-                   hide_stderr=not verbose)
+    _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose)
     if rc != 0:
         if verbose:
             print("Directory %s not under git control" % root)
@@ -282,10 +296,19 @@ def git_pieces_from_vcs(
 
     # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
     # if there isn't one, this yields HEX[-dirty] (no NUM)
-    describe_out, rc = runner(GITS, [
-        "describe", "--tags", "--dirty", "--always", "--long",
-        "--match", f"{tag_prefix}[[:digit:]]*"
-    ], cwd=root)
+    describe_out, rc = runner(
+        GITS,
+        [
+            "describe",
+            "--tags",
+            "--dirty",
+            "--always",
+            "--long",
+            "--match",
+            f"{tag_prefix}[[:digit:]]*",
+        ],
+        cwd=root,
+    )
     # --long was added in git-1.5.5
     if describe_out is None:
         raise NotThisMethod("'git describe' failed")
@@ -300,8 +323,7 @@ def git_pieces_from_vcs(
     pieces["short"] = full_out[:7]  # maybe improved later
     pieces["error"] = None
 
-    branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
-                             cwd=root)
+    branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root)
     # --abbrev-ref was added in git-1.6.3
     if rc != 0 or branch_name is None:
         raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
@@ -341,17 +363,16 @@ def git_pieces_from_vcs(
     dirty = git_describe.endswith("-dirty")
     pieces["dirty"] = dirty
     if dirty:
-        git_describe = git_describe[:git_describe.rindex("-dirty")]
+        git_describe = git_describe[: git_describe.rindex("-dirty")]
 
     # now we have TAG-NUM-gHEX or HEX
 
     if "-" in git_describe:
         # TAG-NUM-gHEX
-        mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
+        mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
         if not mo:
             # unparsable. Maybe git-describe is misbehaving?
-            pieces["error"] = ("unable to parse git-describe output: '%s'"
-                               % describe_out)
+            pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
             return pieces
 
         # tag
@@ -360,10 +381,12 @@ def git_pieces_from_vcs(
             if verbose:
                 fmt = "tag '%s' doesn't start with prefix '%s'"
                 print(fmt % (full_tag, tag_prefix))
-            pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
-                               % (full_tag, tag_prefix))
+            pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
+                full_tag,
+                tag_prefix,
+            )
             return pieces
-        pieces["closest-tag"] = full_tag[len(tag_prefix):]
+        pieces["closest-tag"] = full_tag[len(tag_prefix) :]
 
         # distance: number of commits since tag
         pieces["distance"] = int(mo.group(2))
@@ -412,8 +435,7 @@ def render_pep440(pieces: Dict[str, Any]) -> str:
                 rendered += ".dirty"
     else:
         # exception #1
-        rendered = "0+untagged.%d.g%s" % (pieces["distance"],
-                                          pieces["short"])
+        rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
         if pieces["dirty"]:
             rendered += ".dirty"
     return rendered
@@ -442,8 +464,7 @@ def render_pep440_branch(pieces: Dict[str, Any]) -> str:
         rendered = "0"
         if pieces["branch"] != "master":
             rendered += ".dev0"
-        rendered += "+untagged.%d.g%s" % (pieces["distance"],
-                                          pieces["short"])
+        rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
         if pieces["dirty"]:
             rendered += ".dirty"
     return rendered
@@ -604,11 +625,13 @@ def render_git_describe_long(pieces: Dict[str, Any]) -> str:
 def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]:
     """Render the given version pieces into the requested style."""
     if pieces["error"]:
-        return {"version": "unknown",
-                "full-revisionid": pieces.get("long"),
-                "dirty": None,
-                "error": pieces["error"],
-                "date": None}
+        return {
+            "version": "unknown",
+            "full-revisionid": pieces.get("long"),
+            "dirty": None,
+            "error": pieces["error"],
+            "date": None,
+        }
 
     if not style or style == "default":
         style = "pep440"  # the default
@@ -632,9 +655,13 @@ def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]:
     else:
         raise ValueError("unknown style '%s'" % style)
 
-    return {"version": rendered, "full-revisionid": pieces["long"],
-            "dirty": pieces["dirty"], "error": None,
-            "date": pieces.get("date")}
+    return {
+        "version": rendered,
+        "full-revisionid": pieces["long"],
+        "dirty": pieces["dirty"],
+        "error": None,
+        "date": pieces.get("date"),
+    }
 
 
 def get_versions() -> Dict[str, Any]:
@@ -648,8 +675,7 @@ def get_versions() -> Dict[str, Any]:
     verbose = cfg.verbose
 
     try:
-        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
-                                          verbose)
+        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
     except NotThisMethod:
         pass
 
@@ -658,13 +684,16 @@ def get_versions() -> Dict[str, Any]:
         # versionfile_source is the relative path from the top of the source
         # tree (where the .git directory might live) to this file. Invert
         # this to find the root from __file__.
-        for _ in cfg.versionfile_source.split('/'):
+        for _ in cfg.versionfile_source.split("/"):
             root = os.path.dirname(root)
     except NameError:
-        return {"version": "0+unknown", "full-revisionid": None,
-                "dirty": None,
-                "error": "unable to find root of source tree",
-                "date": None}
+        return {
+            "version": "0+unknown",
+            "full-revisionid": None,
+            "dirty": None,
+            "error": "unable to find root of source tree",
+            "date": None,
+        }
 
     try:
         pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
@@ -678,6 +707,10 @@ def get_versions() -> Dict[str, Any]:
     except NotThisMethod:
         pass
 
-    return {"version": "0+unknown", "full-revisionid": None,
-            "dirty": None,
-            "error": "unable to compute version", "date": None}
+    return {
+        "version": "0+unknown",
+        "full-revisionid": None,
+        "dirty": None,
+        "error": "unable to compute version",
+        "date": None,
+    }
diff --git a/src/pystencilssfg/cli.py b/src/pystencilssfg/cli.py
index 9cf4d4a..e5bdb4d 100644
--- a/src/pystencilssfg/cli.py
+++ b/src/pystencilssfg/cli.py
@@ -5,19 +5,28 @@ from os import path
 from argparse import ArgumentParser, BooleanOptionalAction
 
 from .configuration import (
-    SfgConfigException, SfgConfigSource,
-    add_config_args_to_parser, config_from_parser_args, merge_configurations
+    SfgConfigException,
+    SfgConfigSource,
+    add_config_args_to_parser,
+    config_from_parser_args,
+    merge_configurations,
 )
 
 
 def add_newline_arg(parser):
-    parser.add_argument("--newline", action=BooleanOptionalAction, default=True,
-                        help="Whether to add a terminating newline to the output.")
+    parser.add_argument(
+        "--newline",
+        action=BooleanOptionalAction,
+        default=True,
+        help="Whether to add a terminating newline to the output.",
+    )
 
 
-def cli_main(program='sfg-cli'):
-    parser = ArgumentParser(program,
-                            description="pystencilssfg command-line utility for build system integration")
+def cli_main(program="sfg-cli"):
+    parser = ArgumentParser(
+        program,
+        description="pystencilssfg command-line utility for build system integration",
+    )
 
     subparsers = parser.add_subparsers(required=True, title="Subcommands")
 
@@ -26,25 +35,33 @@ def cli_main(program='sfg-cli'):
     version_parser.set_defaults(func=version)
 
     outfiles_parser = subparsers.add_parser(
-        "list-files", help="List files produced by given codegen script.")
+        "list-files", help="List files produced by given codegen script."
+    )
 
     outfiles_parser.set_defaults(func=list_files)
     add_config_args_to_parser(outfiles_parser)
     add_newline_arg(outfiles_parser)
-    outfiles_parser.add_argument("--sep", type=str, default=" ", dest="sep", help="Separator for list items")
+    outfiles_parser.add_argument(
+        "--sep", type=str, default=" ", dest="sep", help="Separator for list items"
+    )
     outfiles_parser.add_argument("codegen_script", type=str)
 
-    cmake_parser = subparsers.add_parser("cmake", help="Operations for CMake integation")
+    cmake_parser = subparsers.add_parser(
+        "cmake", help="Operations for CMake integation"
+    )
     cmake_subparsers = cmake_parser.add_subparsers(required=True)
 
     modpath = cmake_subparsers.add_parser(
-        "modulepath", help="Print the include path for the pystencils-sfg cmake module")
+        "modulepath", help="Print the include path for the pystencils-sfg cmake module"
+    )
     add_newline_arg(modpath)
     modpath.set_defaults(func=print_cmake_modulepath)
 
-    findmod = cmake_subparsers.add_parser("make-find-module",
-                                          help="Creates the pystencils-sfg CMake find module as" +
-                                          "'FindPystencilsSfg.cmake' in the current directory.")
+    findmod = cmake_subparsers.add_parser(
+        "make-find-module",
+        help="Creates the pystencils-sfg CMake find module as"
+        + "'FindPystencilsSfg.cmake' in the current directory.",
+    )
     findmod.set_defaults(func=make_cmake_find_module)
 
     args = parser.parse_args()
@@ -56,7 +73,7 @@ def cli_main(program='sfg-cli'):
 def version(args):
     from . import __version__
 
-    print(__version__, end=os.linesep if args.newline else '')
+    print(__version__, end=os.linesep if args.newline else "")
 
     exit(0)
 
@@ -76,19 +93,21 @@ def list_files(args):
 
     emitter = HeaderImplPairEmitter(config.get_output_spec(basename))
 
-    print(args.sep.join(emitter.output_files), end=os.linesep if args.newline else '')
+    print(args.sep.join(emitter.output_files), end=os.linesep if args.newline else "")
 
     exit(0)
 
 
 def print_cmake_modulepath(args):
     from .cmake import get_sfg_cmake_modulepath
-    print(get_sfg_cmake_modulepath(), end=os.linesep if args.newline else '')
+
+    print(get_sfg_cmake_modulepath(), end=os.linesep if args.newline else "")
     exit(0)
 
 
 def make_cmake_find_module(args):
     from .cmake import make_find_module
+
     make_find_module()
     exit(0)
 
@@ -100,10 +119,11 @@ def abort_with_config_exception(exception: SfgConfigException):
     match exception.config_source:
         case SfgConfigSource.PROJECT:
             eprint(
-                f"Invalid project configuration: {exception.message}\nCheck your configurator script.")
+                f"Invalid project configuration: {exception.message}\nCheck your configurator script."
+            )
         case SfgConfigSource.COMMANDLINE:
-            eprint(
-                f"Invalid configuration on command line: {exception.message}")
-        case _: assert False, "(Theoretically) unreachable code. Contact the developers."
+            eprint(f"Invalid configuration on command line: {exception.message}")
+        case _:
+            assert False, "(Theoretically) unreachable code. Contact the developers."
 
     exit(1)
diff --git a/src/pystencilssfg/composer/__init__.py b/src/pystencilssfg/composer/__init__.py
index e895fe6..c20f7de 100644
--- a/src/pystencilssfg/composer/__init__.py
+++ b/src/pystencilssfg/composer/__init__.py
@@ -2,9 +2,4 @@ from .composer import SfgComposer
 from .basic_composer import SfgBasicComposer, make_sequence
 from .class_composer import SfgClassComposer
 
-__all__ = [
-    'SfgComposer',
-    "make_sequence",
-    "SfgBasicComposer",
-    "SfgClassComposer"
-]
+__all__ = ["SfgComposer", "make_sequence", "SfgBasicComposer", "SfgClassComposer"]
diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py
index f55e2ec..1421aac 100644
--- a/src/pystencilssfg/composer/basic_composer.py
+++ b/src/pystencilssfg/composer/basic_composer.py
@@ -243,7 +243,10 @@ class SfgNodeBuilder(ABC):
         pass
 
 
-def make_sequence(*args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence:
+SequencerArg = tuple | str | SfgCallTreeNode | SfgNodeBuilder
+
+
+def make_sequence(*args: SequencerArg) -> SfgSequence:
     """Construct a sequence of C++ code from various kinds of arguments.
 
     `make_sequence` is ubiquitous throughout the function building front-end;
@@ -362,13 +365,18 @@ class SfgSwitchBuilder(SfgNodeBuilder):
         if label in self._cases:
             raise SfgException(f"Duplicate case: {label}")
 
-        def sequencer(*args):
+        def sequencer(*args: SequencerArg):
             tree = make_sequence(*args)
             self._cases[label] = tree
             return self
 
         return sequencer
 
+    def cases(self, cases_dict: dict[str, SequencerArg]):
+        for key, value in cases_dict.items():
+            self.case(key)(value)
+        return self
+
     def default(self, *args):
         if self._default is not None:
             raise SfgException("Duplicate default case")
diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py
index e326292..714ecc8 100644
--- a/src/pystencilssfg/composer/class_composer.py
+++ b/src/pystencilssfg/composer/class_composer.py
@@ -218,7 +218,7 @@ class SfgClassComposer:
 
     @staticmethod
     def _resolve_member(
-        arg: (SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str),
+        arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str,
     ):
         if isinstance(arg, SrcObject):
             return SfgMemberVariable(arg.name, arg.dtype)
diff --git a/src/pystencilssfg/composer/custom.py b/src/pystencilssfg/composer/custom.py
index 1b43dd3..26e9b93 100644
--- a/src/pystencilssfg/composer/custom.py
+++ b/src/pystencilssfg/composer/custom.py
@@ -7,5 +7,4 @@ class CustomGenerator(ABC):
     [SfgComposer.generate][pystencilssfg.SfgComposer.generate]."""
 
     @abstractmethod
-    def generate(self, ctx: SfgContext) -> None:
-        ...
+    def generate(self, ctx: SfgContext) -> None: ...
diff --git a/src/pystencilssfg/emission/__init__.py b/src/pystencilssfg/emission/__init__.py
index 74314be..d7478e0 100644
--- a/src/pystencilssfg/emission/__init__.py
+++ b/src/pystencilssfg/emission/__init__.py
@@ -1,5 +1,3 @@
 from .header_impl_pair import HeaderImplPairEmitter
 
-__all__ = [
-    "HeaderImplPairEmitter"
-]
+__all__ = ["HeaderImplPairEmitter"]
diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py
index ae504f4..5c86d5f 100644
--- a/src/pystencilssfg/emission/printers.py
+++ b/src/pystencilssfg/emission/printers.py
@@ -23,7 +23,7 @@ from ..source_components import (
     SfgMemberVariable,
     SfgMethod,
     SfgVisibility,
-    SfgVisibilityBlock
+    SfgVisibilityBlock,
 )
 
 
diff --git a/src/pystencilssfg/exceptions.py b/src/pystencilssfg/exceptions.py
index 1351733..2ee025e 100644
--- a/src/pystencilssfg/exceptions.py
+++ b/src/pystencilssfg/exceptions.py
@@ -1,3 +1,2 @@
-
 class SfgException(Exception):
     pass
diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py
index 0b3d8fa..43d58a0 100644
--- a/src/pystencilssfg/source_components.py
+++ b/src/pystencilssfg/source_components.py
@@ -247,7 +247,9 @@ class SfgClassMember(ABC):
     @property
     def visibility(self) -> SfgVisibility:
         if self._visibility is None:
-            raise SfgException(f"{self} is not bound to a class and therefore has no visibility.")
+            raise SfgException(
+                f"{self} is not bound to a class and therefore has no visibility."
+            )
         return self._visibility
 
     @property
@@ -309,11 +311,7 @@ class SfgInClassDefinition(SfgClassMember):
 
 
 class SfgMemberVariable(SrcObject, SfgClassMember):
-    def __init__(
-        self,
-        name: str,
-        dtype: SrcType
-    ):
+    def __init__(self, name: str, dtype: SrcType):
         SrcObject.__init__(self, name, dtype)
         SfgClassMember.__init__(self)
 
@@ -347,7 +345,7 @@ class SfgConstructor(SfgClassMember):
         self,
         parameters: Sequence[SrcObject] = (),
         initializers: Sequence[str] = (),
-        body: str = ""
+        body: str = "",
     ):
         SfgClassMember.__init__(self)
         self._parameters = tuple(parameters)
@@ -383,6 +381,7 @@ class SfgClass:
     A more succinct interface for constructing classes is available through the
     [SfgClassComposer][pystencilssfg.composer.SfgClassComposer].
     """
+
     def __init__(
         self,
         class_name: str,
@@ -428,7 +427,8 @@ class SfgClass:
     def append_visibility_block(self, block: SfgVisibilityBlock):
         if block.visibility == SfgVisibility.DEFAULT:
             raise SfgException(
-                "Can't add another block with DEFAULT visibility to a class. Use `.default` instead.")
+                "Can't add another block with DEFAULT visibility to a class. Use `.default` instead."
+            )
 
         block._bind(self)
         for m in block.members():
@@ -442,12 +442,11 @@ class SfgClass:
         self, visibility: SfgVisibility | None = None
     ) -> Generator[SfgClassMember, None, None]:
         if visibility is None:
-            yield from chain.from_iterable(
-                b.members() for b in self._blocks
-            )
+            yield from chain.from_iterable(b.members() for b in self._blocks)
         else:
             yield from chain.from_iterable(
-                b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks)
+                b.members()
+                for b in filter(lambda b: b.visibility == visibility, self._blocks)
             )
 
     def definitions(
diff --git a/src/pystencilssfg/source_concepts/__init__.py b/src/pystencilssfg/source_concepts/__init__.py
index 6375f3e..aa4817f 100644
--- a/src/pystencilssfg/source_concepts/__init__.py
+++ b/src/pystencilssfg/source_concepts/__init__.py
@@ -1,5 +1,3 @@
 from .source_objects import SrcObject, SrcField, SrcVector, TypedSymbolOrObject
 
-__all__ = [
-    "SrcObject", "SrcField", "SrcVector", "TypedSymbolOrObject"
-]
+__all__ = ["SrcObject", "SrcField", "SrcVector", "TypedSymbolOrObject"]
diff --git a/src/pystencilssfg/source_concepts/cpp/std_mdspan.py b/src/pystencilssfg/source_concepts/cpp/std_mdspan.py
index 2e57c52..11645a5 100644
--- a/src/pystencilssfg/source_concepts/cpp/std_mdspan.py
+++ b/src/pystencilssfg/source_concepts/cpp/std_mdspan.py
@@ -15,16 +15,23 @@ from ...exceptions import SfgException
 class StdMdspan(SrcField):
     dynamic_extent = "std::dynamic_extent"
 
-    def __init__(self, identifer: str,
-                 T: PsType,
-                 extents: tuple[int | str, ...],
-                 extents_type: PsType = int,
-                 reference: bool = False):
+    def __init__(
+        self,
+        identifer: str,
+        T: PsType,
+        extents: tuple[int | str, ...],
+        extents_type: PsType = int,
+        reference: bool = False,
+    ):
         cpp_typestr = cpp_typename(T)
         extents_type_str = cpp_typename(extents_type)
 
-        extents_str = f"std::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"
-        typestring = f"std::mdspan< {cpp_typestr}, {extents_str} > {'&' if reference else ''}"
+        extents_str = (
+            f"std::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"
+        )
+        typestring = (
+            f"std::mdspan< {cpp_typestr}, {extents_str} > {'&' if reference else ''}"
+        )
         super().__init__(identifer, SrcType(typestring))
 
         self._extents = extents
@@ -36,49 +43,61 @@ class StdMdspan(SrcField):
     def extract_ptr(self, ptr_symbol: FieldPointerSymbol):
         return SfgStatements(
             f"{ptr_symbol.dtype} {ptr_symbol.name} = {self._identifier}.data_handle();",
-            (ptr_symbol, ),
-            (self, )
+            (ptr_symbol,),
+            (self,),
         )
 
-    def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements:
+    def extract_size(
+        self, coordinate: int, size: Union[int, FieldShapeSymbol]
+    ) -> SfgStatements:
         dim = len(self._extents)
         if coordinate >= dim:
             if isinstance(size, FieldShapeSymbol):
-                raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {dim}-dimensional mdspan!")
+                raise SfgException(
+                    f"Cannot extract size in coordinate {coordinate} from a {dim}-dimensional mdspan!"
+                )
             elif size != 1:
                 raise SfgException(
-                    f"Cannot map field with size {size} in coordinate {coordinate} to {dim}-dimensional mdspan!")
+                    f"Cannot map field with size {size} in coordinate {coordinate} to {dim}-dimensional mdspan!"
+                )
             else:
                 #   trivial trailing index dimensions are OK -> do nothing
-                return SfgStatements(f"// {self._identifier}.extents().extent({coordinate}) == 1", (), ())
+                return SfgStatements(
+                    f"// {self._identifier}.extents().extent({coordinate}) == 1", (), ()
+                )
 
         if isinstance(size, FieldShapeSymbol):
             return SfgStatements(
                 f"{size.dtype} {size.name} = {self._identifier}.extents().extent({coordinate});",
-                (size, ),
-                (self, )
+                (size,),
+                (self,),
             )
         else:
             return SfgStatements(
                 f"assert( {self._identifier}.extents().extent({coordinate}) == {size} );",
-                (), (self, )
+                (),
+                (self,),
             )
 
-    def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements:
+    def extract_stride(
+        self, coordinate: int, stride: Union[int, FieldStrideSymbol]
+    ) -> SfgStatements:
         if coordinate >= len(self._extents):
             raise SfgException(
-                f"Cannot extract stride in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan")
+                f"Cannot extract stride in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan"
+            )
 
         if isinstance(stride, FieldStrideSymbol):
             return SfgStatements(
                 f"{stride.dtype} {stride.name} = {self._identifier}.stride({coordinate});",
-                (stride, ),
-                (self, )
+                (stride,),
+                (self,),
             )
         else:
             return SfgStatements(
                 f"assert( {self._identifier}.stride({coordinate}) == {stride} );",
-                (), (self, )
+                (),
+                (self,),
             )
 
 
@@ -87,18 +106,29 @@ def mdspan_ref(field: Field, extents_type: type = np.uint32):
     from pystencils.field import layout_string_to_tuple
 
     if field.layout != layout_string_to_tuple("soa", field.spatial_dimensions):
-        raise NotImplementedError("mdspan mapping is currently only available for structure-of-arrays fields")
+        raise NotImplementedError(
+            "mdspan mapping is currently only available for structure-of-arrays fields"
+        )
 
     extents: list[str | int] = []
 
     for s in field.spatial_shape:
-        extents.append(StdMdspan.dynamic_extent if isinstance(s, FieldShapeSymbol) else cast(int, s))
+        extents.append(
+            StdMdspan.dynamic_extent
+            if isinstance(s, FieldShapeSymbol)
+            else cast(int, s)
+        )
 
     if field.index_shape != (1,):
         for s in field.index_shape:
-            extents += StdMdspan.dynamic_extent if isinstance(s, FieldShapeSymbol) else s
+            extents += (
+                StdMdspan.dynamic_extent if isinstance(s, FieldShapeSymbol) else s
+            )
 
-    return StdMdspan(field.name, field.dtype,
-                     tuple(extents),
-                     extents_type=extents_type,
-                     reference=True)
+    return StdMdspan(
+        field.name,
+        field.dtype,
+        tuple(extents),
+        extents_type=extents_type,
+        reference=True,
+    )
diff --git a/src/pystencilssfg/tree/__init__.py b/src/pystencilssfg/tree/__init__.py
index 46d42f8..15cd329 100644
--- a/src/pystencilssfg/tree/__init__.py
+++ b/src/pystencilssfg/tree/__init__.py
@@ -7,7 +7,7 @@ from .basic_nodes import (
     SfgSequence,
     SfgStatements,
     SfgFunctionParams,
-    SfgRequireIncludes
+    SfgRequireIncludes,
 )
 from .conditional import SfgBranch, SfgCondition, IntEven, IntOdd
 
diff --git a/src/pystencilssfg/tree/basic_nodes.py b/src/pystencilssfg/tree/basic_nodes.py
index e209e0a..26fe8ed 100644
--- a/src/pystencilssfg/tree/basic_nodes.py
+++ b/src/pystencilssfg/tree/basic_nodes.py
@@ -78,8 +78,7 @@ class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
 
     @property
     @abstractmethod
-    def required_parameters(self) -> set[TypedSymbolOrObject]:
-        ...
+    def required_parameters(self) -> set[TypedSymbolOrObject]: ...
 
 
 class SfgEmptyNode(SfgCallTreeLeaf):
diff --git a/src/pystencilssfg/tree/deferred_nodes.py b/src/pystencilssfg/tree/deferred_nodes.py
index bb3cab4..040b51b 100644
--- a/src/pystencilssfg/tree/deferred_nodes.py
+++ b/src/pystencilssfg/tree/deferred_nodes.py
@@ -27,19 +27,22 @@ class SfgDeferredNode(SfgCallTreeNode, ABC):
 
     class InvalidAccess:
         def __get__(self):
-            raise SfgException("Invalid access into deferred node; deferred nodes must be expanded first.")
+            raise SfgException(
+                "Invalid access into deferred node; deferred nodes must be expanded first."
+            )
 
     def __init__(self):
         self._children = SfgDeferredNode.InvalidAccess
 
     def get_code(self, ctx: SfgContext) -> str:
-        raise SfgException("Invalid access into deferred node; deferred nodes must be expanded first.")
+        raise SfgException(
+            "Invalid access into deferred node; deferred nodes must be expanded first."
+        )
 
 
 class SfgParamCollectionDeferredNode(SfgDeferredNode, ABC):
     @abstractmethod
-    def expand(self, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode:
-        ...
+    def expand(self, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode: ...
 
 
 class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode):
@@ -51,9 +54,14 @@ class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode):
         #    Find field pointer
         ptr = None
         for param in visible_params:
-            if isinstance(param, FieldPointerSymbol) and param.field_name == self._field.name:
+            if (
+                isinstance(param, FieldPointerSymbol)
+                and param.field_name == self._field.name
+            ):
                 if param.dtype.base_type != self._field.dtype:
-                    raise SfgException("Data type mismatch between field and encountered pointer symbol")
+                    raise SfgException(
+                        "Data type mismatch between field and encountered pointer symbol"
+                    )
                 ptr = param
 
         #   Find required sizes
diff --git a/src/pystencilssfg/types.py b/src/pystencilssfg/types.py
index 250b0cc..4936798 100644
--- a/src/pystencilssfg/types.py
+++ b/src/pystencilssfg/types.py
@@ -19,7 +19,7 @@ PsType is a temporary solution and will be removed in the future
 in favor of the consolidated pystencils backend typing system.
 """
 
-SrcType = NewType('SrcType', str)
+SrcType = NewType("SrcType", str)
 """C/C++-Types occuring during source file generation.
 
 When necessary, the SFG package checks equality of types by their name strings; it does
diff --git a/src/pystencilssfg/visitors/tree_visitors.py b/src/pystencilssfg/visitors/tree_visitors.py
index 76cb9c5..bb596b5 100644
--- a/src/pystencilssfg/visitors/tree_visitors.py
+++ b/src/pystencilssfg/visitors/tree_visitors.py
@@ -11,8 +11,9 @@ from ..tree.basic_nodes import (
     SfgStatements,
 )
 from ..tree.deferred_nodes import SfgParamCollectionDeferredNode
+from ..tree.conditional import SfgSwitch
 from .dispatcher import visitor
-from ..source_concepts.source_objects import TypedSymbolOrObject
+from ..source_concepts.source_objects import TypedSymbol, SrcObject, TypedSymbolOrObject
 
 
 class FlattenSequences:
@@ -58,6 +59,13 @@ class ExpandingParameterCollector:
     def leaf(self, leaf: SfgCallTreeLeaf) -> set[TypedSymbolOrObject]:
         return leaf.required_parameters
 
+    @visit.case(SfgSwitch)
+    def switch(self, sw: SfgSwitch) -> set[TypedSymbolOrObject]:
+        params = self.branching_node(sw)
+        if isinstance(sw.switch_arg, (TypedSymbol, SrcObject)):
+            params.add(sw.switch_arg)
+        return params
+
     @visit.case(SfgSequence)
     def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]:
         """
-- 
GitLab