From be8dd46e126adad8c70f45d60527a5f9e30004bd Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 8 Mar 2024 16:26:01 +0100 Subject: [PATCH] extended switch/case support --- .flake8 | 1 + src/pystencilssfg/__init__.py | 7 +- src/pystencilssfg/__main__.py | 2 +- src/pystencilssfg/_version.py | 157 +++++++++++------- src/pystencilssfg/cli.py | 62 ++++--- src/pystencilssfg/composer/__init__.py | 7 +- src/pystencilssfg/composer/basic_composer.py | 12 +- src/pystencilssfg/composer/class_composer.py | 2 +- src/pystencilssfg/composer/custom.py | 3 +- src/pystencilssfg/emission/__init__.py | 4 +- src/pystencilssfg/emission/printers.py | 2 +- src/pystencilssfg/exceptions.py | 1 - src/pystencilssfg/source_components.py | 23 ++- src/pystencilssfg/source_concepts/__init__.py | 4 +- .../source_concepts/cpp/std_mdspan.py | 86 ++++++---- src/pystencilssfg/tree/__init__.py | 2 +- src/pystencilssfg/tree/basic_nodes.py | 3 +- src/pystencilssfg/tree/deferred_nodes.py | 20 ++- src/pystencilssfg/types.py | 2 +- src/pystencilssfg/visitors/tree_visitors.py | 10 +- 20 files changed, 252 insertions(+), 158 deletions(-) diff --git a/.flake8 b/.flake8 index aa079ec..313ba09 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,3 @@ [flake8] max-line-length=120 +exclude = src/pystencilssfg/_version.py diff --git a/src/pystencilssfg/__init__.py b/src/pystencilssfg/__init__.py index 48de907..457c89c 100644 --- a/src/pystencilssfg/__init__.py +++ b/src/pystencilssfg/__init__.py @@ -3,9 +3,8 @@ from .generator import SourceFileGenerator from .composer import SfgComposer from .context import SfgContext -__all__ = [ - "SourceFileGenerator", "SfgComposer", "SfgConfiguration", "SfgContext" -] +__all__ = ["SourceFileGenerator", "SfgComposer", "SfgConfiguration", "SfgContext"] from . import _version -__version__ = _version.get_versions()['version'] + +__version__ = _version.get_versions()["version"] diff --git a/src/pystencilssfg/__main__.py b/src/pystencilssfg/__main__.py index ce41802..438005a 100644 --- a/src/pystencilssfg/__main__.py +++ b/src/pystencilssfg/__main__.py @@ -1,4 +1,4 @@ - if __name__ == "__main__": from .cli import cli_main + cli_main("python -m pystencilssfg") diff --git a/src/pystencilssfg/_version.py b/src/pystencilssfg/_version.py index a5947d7..c9dee16 100644 --- a/src/pystencilssfg/_version.py +++ b/src/pystencilssfg/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -68,12 +67,14 @@ HANDLERS: Dict[str, Dict[str, Callable]] = {} def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate @@ -100,10 +101,14 @@ def run_command( try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) break except OSError as e: if e.errno == errno.ENOENT: @@ -141,15 +146,21 @@ def versions_from_parentdir( for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -212,7 +223,7 @@ def git_versions_from_keywords( # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -221,7 +232,7 @@ def git_versions_from_keywords( # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -229,32 +240,36 @@ def git_versions_from_keywords( for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs( - tag_prefix: str, - root: str, - verbose: bool, - runner: Callable = run_command + tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command ) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. @@ -273,8 +288,7 @@ def git_pieces_from_vcs( env.pop("GIT_DIR", None) runner = functools.partial(runner, env=env) - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -282,10 +296,19 @@ def git_pieces_from_vcs( # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -300,8 +323,7 @@ def git_pieces_from_vcs( pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -341,17 +363,16 @@ def git_pieces_from_vcs( dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -360,10 +381,12 @@ def git_pieces_from_vcs( if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -412,8 +435,7 @@ def render_pep440(pieces: Dict[str, Any]) -> str: rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -442,8 +464,7 @@ def render_pep440_branch(pieces: Dict[str, Any]) -> str: rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -604,11 +625,13 @@ def render_git_describe_long(pieces: Dict[str, Any]) -> str: def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -632,9 +655,13 @@ def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions() -> Dict[str, Any]: @@ -648,8 +675,7 @@ def get_versions() -> Dict[str, Any]: verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -658,13 +684,16 @@ def get_versions() -> Dict[str, Any]: # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -678,6 +707,10 @@ def get_versions() -> Dict[str, Any]: except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/src/pystencilssfg/cli.py b/src/pystencilssfg/cli.py index 9cf4d4a..e5bdb4d 100644 --- a/src/pystencilssfg/cli.py +++ b/src/pystencilssfg/cli.py @@ -5,19 +5,28 @@ from os import path from argparse import ArgumentParser, BooleanOptionalAction from .configuration import ( - SfgConfigException, SfgConfigSource, - add_config_args_to_parser, config_from_parser_args, merge_configurations + SfgConfigException, + SfgConfigSource, + add_config_args_to_parser, + config_from_parser_args, + merge_configurations, ) def add_newline_arg(parser): - parser.add_argument("--newline", action=BooleanOptionalAction, default=True, - help="Whether to add a terminating newline to the output.") + parser.add_argument( + "--newline", + action=BooleanOptionalAction, + default=True, + help="Whether to add a terminating newline to the output.", + ) -def cli_main(program='sfg-cli'): - parser = ArgumentParser(program, - description="pystencilssfg command-line utility for build system integration") +def cli_main(program="sfg-cli"): + parser = ArgumentParser( + program, + description="pystencilssfg command-line utility for build system integration", + ) subparsers = parser.add_subparsers(required=True, title="Subcommands") @@ -26,25 +35,33 @@ def cli_main(program='sfg-cli'): version_parser.set_defaults(func=version) outfiles_parser = subparsers.add_parser( - "list-files", help="List files produced by given codegen script.") + "list-files", help="List files produced by given codegen script." + ) outfiles_parser.set_defaults(func=list_files) add_config_args_to_parser(outfiles_parser) add_newline_arg(outfiles_parser) - outfiles_parser.add_argument("--sep", type=str, default=" ", dest="sep", help="Separator for list items") + outfiles_parser.add_argument( + "--sep", type=str, default=" ", dest="sep", help="Separator for list items" + ) outfiles_parser.add_argument("codegen_script", type=str) - cmake_parser = subparsers.add_parser("cmake", help="Operations for CMake integation") + cmake_parser = subparsers.add_parser( + "cmake", help="Operations for CMake integation" + ) cmake_subparsers = cmake_parser.add_subparsers(required=True) modpath = cmake_subparsers.add_parser( - "modulepath", help="Print the include path for the pystencils-sfg cmake module") + "modulepath", help="Print the include path for the pystencils-sfg cmake module" + ) add_newline_arg(modpath) modpath.set_defaults(func=print_cmake_modulepath) - findmod = cmake_subparsers.add_parser("make-find-module", - help="Creates the pystencils-sfg CMake find module as" + - "'FindPystencilsSfg.cmake' in the current directory.") + findmod = cmake_subparsers.add_parser( + "make-find-module", + help="Creates the pystencils-sfg CMake find module as" + + "'FindPystencilsSfg.cmake' in the current directory.", + ) findmod.set_defaults(func=make_cmake_find_module) args = parser.parse_args() @@ -56,7 +73,7 @@ def cli_main(program='sfg-cli'): def version(args): from . import __version__ - print(__version__, end=os.linesep if args.newline else '') + print(__version__, end=os.linesep if args.newline else "") exit(0) @@ -76,19 +93,21 @@ def list_files(args): emitter = HeaderImplPairEmitter(config.get_output_spec(basename)) - print(args.sep.join(emitter.output_files), end=os.linesep if args.newline else '') + print(args.sep.join(emitter.output_files), end=os.linesep if args.newline else "") exit(0) def print_cmake_modulepath(args): from .cmake import get_sfg_cmake_modulepath - print(get_sfg_cmake_modulepath(), end=os.linesep if args.newline else '') + + print(get_sfg_cmake_modulepath(), end=os.linesep if args.newline else "") exit(0) def make_cmake_find_module(args): from .cmake import make_find_module + make_find_module() exit(0) @@ -100,10 +119,11 @@ def abort_with_config_exception(exception: SfgConfigException): match exception.config_source: case SfgConfigSource.PROJECT: eprint( - f"Invalid project configuration: {exception.message}\nCheck your configurator script.") + f"Invalid project configuration: {exception.message}\nCheck your configurator script." + ) case SfgConfigSource.COMMANDLINE: - eprint( - f"Invalid configuration on command line: {exception.message}") - case _: assert False, "(Theoretically) unreachable code. Contact the developers." + eprint(f"Invalid configuration on command line: {exception.message}") + case _: + assert False, "(Theoretically) unreachable code. Contact the developers." exit(1) diff --git a/src/pystencilssfg/composer/__init__.py b/src/pystencilssfg/composer/__init__.py index e895fe6..c20f7de 100644 --- a/src/pystencilssfg/composer/__init__.py +++ b/src/pystencilssfg/composer/__init__.py @@ -2,9 +2,4 @@ from .composer import SfgComposer from .basic_composer import SfgBasicComposer, make_sequence from .class_composer import SfgClassComposer -__all__ = [ - 'SfgComposer', - "make_sequence", - "SfgBasicComposer", - "SfgClassComposer" -] +__all__ = ["SfgComposer", "make_sequence", "SfgBasicComposer", "SfgClassComposer"] diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index f55e2ec..1421aac 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -243,7 +243,10 @@ class SfgNodeBuilder(ABC): pass -def make_sequence(*args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence: +SequencerArg = tuple | str | SfgCallTreeNode | SfgNodeBuilder + + +def make_sequence(*args: SequencerArg) -> SfgSequence: """Construct a sequence of C++ code from various kinds of arguments. `make_sequence` is ubiquitous throughout the function building front-end; @@ -362,13 +365,18 @@ class SfgSwitchBuilder(SfgNodeBuilder): if label in self._cases: raise SfgException(f"Duplicate case: {label}") - def sequencer(*args): + def sequencer(*args: SequencerArg): tree = make_sequence(*args) self._cases[label] = tree return self return sequencer + def cases(self, cases_dict: dict[str, SequencerArg]): + for key, value in cases_dict.items(): + self.case(key)(value) + return self + def default(self, *args): if self._default is not None: raise SfgException("Duplicate default case") diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index e326292..714ecc8 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -218,7 +218,7 @@ class SfgClassComposer: @staticmethod def _resolve_member( - arg: (SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str), + arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str, ): if isinstance(arg, SrcObject): return SfgMemberVariable(arg.name, arg.dtype) diff --git a/src/pystencilssfg/composer/custom.py b/src/pystencilssfg/composer/custom.py index 1b43dd3..26e9b93 100644 --- a/src/pystencilssfg/composer/custom.py +++ b/src/pystencilssfg/composer/custom.py @@ -7,5 +7,4 @@ class CustomGenerator(ABC): [SfgComposer.generate][pystencilssfg.SfgComposer.generate].""" @abstractmethod - def generate(self, ctx: SfgContext) -> None: - ... + def generate(self, ctx: SfgContext) -> None: ... diff --git a/src/pystencilssfg/emission/__init__.py b/src/pystencilssfg/emission/__init__.py index 74314be..d7478e0 100644 --- a/src/pystencilssfg/emission/__init__.py +++ b/src/pystencilssfg/emission/__init__.py @@ -1,5 +1,3 @@ from .header_impl_pair import HeaderImplPairEmitter -__all__ = [ - "HeaderImplPairEmitter" -] +__all__ = ["HeaderImplPairEmitter"] diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index ae504f4..5c86d5f 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -23,7 +23,7 @@ from ..source_components import ( SfgMemberVariable, SfgMethod, SfgVisibility, - SfgVisibilityBlock + SfgVisibilityBlock, ) diff --git a/src/pystencilssfg/exceptions.py b/src/pystencilssfg/exceptions.py index 1351733..2ee025e 100644 --- a/src/pystencilssfg/exceptions.py +++ b/src/pystencilssfg/exceptions.py @@ -1,3 +1,2 @@ - class SfgException(Exception): pass diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index 0b3d8fa..43d58a0 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -247,7 +247,9 @@ class SfgClassMember(ABC): @property def visibility(self) -> SfgVisibility: if self._visibility is None: - raise SfgException(f"{self} is not bound to a class and therefore has no visibility.") + raise SfgException( + f"{self} is not bound to a class and therefore has no visibility." + ) return self._visibility @property @@ -309,11 +311,7 @@ class SfgInClassDefinition(SfgClassMember): class SfgMemberVariable(SrcObject, SfgClassMember): - def __init__( - self, - name: str, - dtype: SrcType - ): + def __init__(self, name: str, dtype: SrcType): SrcObject.__init__(self, name, dtype) SfgClassMember.__init__(self) @@ -347,7 +345,7 @@ class SfgConstructor(SfgClassMember): self, parameters: Sequence[SrcObject] = (), initializers: Sequence[str] = (), - body: str = "" + body: str = "", ): SfgClassMember.__init__(self) self._parameters = tuple(parameters) @@ -383,6 +381,7 @@ class SfgClass: A more succinct interface for constructing classes is available through the [SfgClassComposer][pystencilssfg.composer.SfgClassComposer]. """ + def __init__( self, class_name: str, @@ -428,7 +427,8 @@ class SfgClass: def append_visibility_block(self, block: SfgVisibilityBlock): if block.visibility == SfgVisibility.DEFAULT: raise SfgException( - "Can't add another block with DEFAULT visibility to a class. Use `.default` instead.") + "Can't add another block with DEFAULT visibility to a class. Use `.default` instead." + ) block._bind(self) for m in block.members(): @@ -442,12 +442,11 @@ class SfgClass: self, visibility: SfgVisibility | None = None ) -> Generator[SfgClassMember, None, None]: if visibility is None: - yield from chain.from_iterable( - b.members() for b in self._blocks - ) + yield from chain.from_iterable(b.members() for b in self._blocks) else: yield from chain.from_iterable( - b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks) + b.members() + for b in filter(lambda b: b.visibility == visibility, self._blocks) ) def definitions( diff --git a/src/pystencilssfg/source_concepts/__init__.py b/src/pystencilssfg/source_concepts/__init__.py index 6375f3e..aa4817f 100644 --- a/src/pystencilssfg/source_concepts/__init__.py +++ b/src/pystencilssfg/source_concepts/__init__.py @@ -1,5 +1,3 @@ from .source_objects import SrcObject, SrcField, SrcVector, TypedSymbolOrObject -__all__ = [ - "SrcObject", "SrcField", "SrcVector", "TypedSymbolOrObject" -] +__all__ = ["SrcObject", "SrcField", "SrcVector", "TypedSymbolOrObject"] diff --git a/src/pystencilssfg/source_concepts/cpp/std_mdspan.py b/src/pystencilssfg/source_concepts/cpp/std_mdspan.py index 2e57c52..11645a5 100644 --- a/src/pystencilssfg/source_concepts/cpp/std_mdspan.py +++ b/src/pystencilssfg/source_concepts/cpp/std_mdspan.py @@ -15,16 +15,23 @@ from ...exceptions import SfgException class StdMdspan(SrcField): dynamic_extent = "std::dynamic_extent" - def __init__(self, identifer: str, - T: PsType, - extents: tuple[int | str, ...], - extents_type: PsType = int, - reference: bool = False): + def __init__( + self, + identifer: str, + T: PsType, + extents: tuple[int | str, ...], + extents_type: PsType = int, + reference: bool = False, + ): cpp_typestr = cpp_typename(T) extents_type_str = cpp_typename(extents_type) - extents_str = f"std::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >" - typestring = f"std::mdspan< {cpp_typestr}, {extents_str} > {'&' if reference else ''}" + extents_str = ( + f"std::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >" + ) + typestring = ( + f"std::mdspan< {cpp_typestr}, {extents_str} > {'&' if reference else ''}" + ) super().__init__(identifer, SrcType(typestring)) self._extents = extents @@ -36,49 +43,61 @@ class StdMdspan(SrcField): def extract_ptr(self, ptr_symbol: FieldPointerSymbol): return SfgStatements( f"{ptr_symbol.dtype} {ptr_symbol.name} = {self._identifier}.data_handle();", - (ptr_symbol, ), - (self, ) + (ptr_symbol,), + (self,), ) - def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements: + def extract_size( + self, coordinate: int, size: Union[int, FieldShapeSymbol] + ) -> SfgStatements: dim = len(self._extents) if coordinate >= dim: if isinstance(size, FieldShapeSymbol): - raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {dim}-dimensional mdspan!") + raise SfgException( + f"Cannot extract size in coordinate {coordinate} from a {dim}-dimensional mdspan!" + ) elif size != 1: raise SfgException( - f"Cannot map field with size {size} in coordinate {coordinate} to {dim}-dimensional mdspan!") + f"Cannot map field with size {size} in coordinate {coordinate} to {dim}-dimensional mdspan!" + ) else: # trivial trailing index dimensions are OK -> do nothing - return SfgStatements(f"// {self._identifier}.extents().extent({coordinate}) == 1", (), ()) + return SfgStatements( + f"// {self._identifier}.extents().extent({coordinate}) == 1", (), () + ) if isinstance(size, FieldShapeSymbol): return SfgStatements( f"{size.dtype} {size.name} = {self._identifier}.extents().extent({coordinate});", - (size, ), - (self, ) + (size,), + (self,), ) else: return SfgStatements( f"assert( {self._identifier}.extents().extent({coordinate}) == {size} );", - (), (self, ) + (), + (self,), ) - def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements: + def extract_stride( + self, coordinate: int, stride: Union[int, FieldStrideSymbol] + ) -> SfgStatements: if coordinate >= len(self._extents): raise SfgException( - f"Cannot extract stride in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan") + f"Cannot extract stride in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan" + ) if isinstance(stride, FieldStrideSymbol): return SfgStatements( f"{stride.dtype} {stride.name} = {self._identifier}.stride({coordinate});", - (stride, ), - (self, ) + (stride,), + (self,), ) else: return SfgStatements( f"assert( {self._identifier}.stride({coordinate}) == {stride} );", - (), (self, ) + (), + (self,), ) @@ -87,18 +106,29 @@ def mdspan_ref(field: Field, extents_type: type = np.uint32): from pystencils.field import layout_string_to_tuple if field.layout != layout_string_to_tuple("soa", field.spatial_dimensions): - raise NotImplementedError("mdspan mapping is currently only available for structure-of-arrays fields") + raise NotImplementedError( + "mdspan mapping is currently only available for structure-of-arrays fields" + ) extents: list[str | int] = [] for s in field.spatial_shape: - extents.append(StdMdspan.dynamic_extent if isinstance(s, FieldShapeSymbol) else cast(int, s)) + extents.append( + StdMdspan.dynamic_extent + if isinstance(s, FieldShapeSymbol) + else cast(int, s) + ) if field.index_shape != (1,): for s in field.index_shape: - extents += StdMdspan.dynamic_extent if isinstance(s, FieldShapeSymbol) else s + extents += ( + StdMdspan.dynamic_extent if isinstance(s, FieldShapeSymbol) else s + ) - return StdMdspan(field.name, field.dtype, - tuple(extents), - extents_type=extents_type, - reference=True) + return StdMdspan( + field.name, + field.dtype, + tuple(extents), + extents_type=extents_type, + reference=True, + ) diff --git a/src/pystencilssfg/tree/__init__.py b/src/pystencilssfg/tree/__init__.py index 46d42f8..15cd329 100644 --- a/src/pystencilssfg/tree/__init__.py +++ b/src/pystencilssfg/tree/__init__.py @@ -7,7 +7,7 @@ from .basic_nodes import ( SfgSequence, SfgStatements, SfgFunctionParams, - SfgRequireIncludes + SfgRequireIncludes, ) from .conditional import SfgBranch, SfgCondition, IntEven, IntOdd diff --git a/src/pystencilssfg/tree/basic_nodes.py b/src/pystencilssfg/tree/basic_nodes.py index e209e0a..26fe8ed 100644 --- a/src/pystencilssfg/tree/basic_nodes.py +++ b/src/pystencilssfg/tree/basic_nodes.py @@ -78,8 +78,7 @@ class SfgCallTreeLeaf(SfgCallTreeNode, ABC): @property @abstractmethod - def required_parameters(self) -> set[TypedSymbolOrObject]: - ... + def required_parameters(self) -> set[TypedSymbolOrObject]: ... class SfgEmptyNode(SfgCallTreeLeaf): diff --git a/src/pystencilssfg/tree/deferred_nodes.py b/src/pystencilssfg/tree/deferred_nodes.py index bb3cab4..040b51b 100644 --- a/src/pystencilssfg/tree/deferred_nodes.py +++ b/src/pystencilssfg/tree/deferred_nodes.py @@ -27,19 +27,22 @@ class SfgDeferredNode(SfgCallTreeNode, ABC): class InvalidAccess: def __get__(self): - raise SfgException("Invalid access into deferred node; deferred nodes must be expanded first.") + raise SfgException( + "Invalid access into deferred node; deferred nodes must be expanded first." + ) def __init__(self): self._children = SfgDeferredNode.InvalidAccess def get_code(self, ctx: SfgContext) -> str: - raise SfgException("Invalid access into deferred node; deferred nodes must be expanded first.") + raise SfgException( + "Invalid access into deferred node; deferred nodes must be expanded first." + ) class SfgParamCollectionDeferredNode(SfgDeferredNode, ABC): @abstractmethod - def expand(self, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode: - ... + def expand(self, visible_params: set[TypedSymbolOrObject]) -> SfgCallTreeNode: ... class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode): @@ -51,9 +54,14 @@ class SfgDeferredFieldMapping(SfgParamCollectionDeferredNode): # Find field pointer ptr = None for param in visible_params: - if isinstance(param, FieldPointerSymbol) and param.field_name == self._field.name: + if ( + isinstance(param, FieldPointerSymbol) + and param.field_name == self._field.name + ): if param.dtype.base_type != self._field.dtype: - raise SfgException("Data type mismatch between field and encountered pointer symbol") + raise SfgException( + "Data type mismatch between field and encountered pointer symbol" + ) ptr = param # Find required sizes diff --git a/src/pystencilssfg/types.py b/src/pystencilssfg/types.py index 250b0cc..4936798 100644 --- a/src/pystencilssfg/types.py +++ b/src/pystencilssfg/types.py @@ -19,7 +19,7 @@ PsType is a temporary solution and will be removed in the future in favor of the consolidated pystencils backend typing system. """ -SrcType = NewType('SrcType', str) +SrcType = NewType("SrcType", str) """C/C++-Types occuring during source file generation. When necessary, the SFG package checks equality of types by their name strings; it does diff --git a/src/pystencilssfg/visitors/tree_visitors.py b/src/pystencilssfg/visitors/tree_visitors.py index 76cb9c5..bb596b5 100644 --- a/src/pystencilssfg/visitors/tree_visitors.py +++ b/src/pystencilssfg/visitors/tree_visitors.py @@ -11,8 +11,9 @@ from ..tree.basic_nodes import ( SfgStatements, ) from ..tree.deferred_nodes import SfgParamCollectionDeferredNode +from ..tree.conditional import SfgSwitch from .dispatcher import visitor -from ..source_concepts.source_objects import TypedSymbolOrObject +from ..source_concepts.source_objects import TypedSymbol, SrcObject, TypedSymbolOrObject class FlattenSequences: @@ -58,6 +59,13 @@ class ExpandingParameterCollector: def leaf(self, leaf: SfgCallTreeLeaf) -> set[TypedSymbolOrObject]: return leaf.required_parameters + @visit.case(SfgSwitch) + def switch(self, sw: SfgSwitch) -> set[TypedSymbolOrObject]: + params = self.branching_node(sw) + if isinstance(sw.switch_arg, (TypedSymbol, SrcObject)): + params.add(sw.switch_arg) + return params + @visit.case(SfgSequence) def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]: """ -- GitLab