diff --git a/src/pystencilssfg/cli.py b/src/pystencilssfg/cli.py index 3b321c2d0be13b906384701080b8de870647d507..8cd9f4f552014d3ef4770e63c961bbe8680dcdc9 100644 --- a/src/pystencilssfg/cli.py +++ b/src/pystencilssfg/cli.py @@ -1,11 +1,12 @@ import sys import os from os import path +from pathlib import Path +from typing import NoReturn from argparse import ArgumentParser, BooleanOptionalAction from .config import CommandLineParameters, SfgConfigException, OutputMode -from .emission import OutputSpec def add_newline_arg(parser): @@ -17,7 +18,7 @@ def add_newline_arg(parser): ) -def cli_main(program="sfg-cli"): +def cli_main(program="sfg-cli") -> NoReturn: parser = ArgumentParser( program, description="pystencilssfg command-line utility for build system integration", @@ -65,7 +66,7 @@ def cli_main(program="sfg-cli"): exit(-1) # should never happen -def version(args): +def version(args) -> NoReturn: from . import __version__ print(__version__, end=os.linesep if args.newline else "") @@ -73,37 +74,43 @@ def version(args): exit(0) -def list_files(args): +def list_files(args) -> NoReturn: cli_params = CommandLineParameters(args) config = cli_params.get_config() _, scriptname = path.split(args.codegen_script) basename = path.splitext(scriptname)[0] - output_spec = OutputSpec.create(config, basename) - output_files = [output_spec.get_header_filepath()] + output_dir: Path = config.get_option("output_directory") + + header_ext = config.extensions.get_option("header") + output_files = [output_dir / f"{basename}.{header_ext}"] if config.output_mode != OutputMode.HEADER_ONLY: - output_files.append(output_spec.get_impl_filepath()) + impl_ext = config.extensions.get_option("impl") + output_files.append(output_dir / f"{basename}.{impl_ext}") - print(args.sep.join(output_files), end=os.linesep if args.newline else "") + print( + args.sep.join(str(of) for of in output_files), + end=os.linesep if args.newline else "", + ) exit(0) -def print_cmake_modulepath(args): +def print_cmake_modulepath(args) -> NoReturn: from .cmake import get_sfg_cmake_modulepath print(get_sfg_cmake_modulepath(), end=os.linesep if args.newline else "") exit(0) -def make_cmake_find_module(args): +def make_cmake_find_module(args) -> NoReturn: from .cmake import make_find_module make_find_module() exit(0) -def abort_with_config_exception(exception: SfgConfigException, source: str): +def abort_with_config_exception(exception: SfgConfigException, source: str) -> NoReturn: print(f"Invalid {source} configuration: {exception.args[0]}.", file=sys.stderr) exit(1) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 3ed5c0fe0a432d3fbd8e2e7729a6d0f06a58104f..0a72e8089ecd5e32be53cd335df57b58b21ec578 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -24,6 +24,7 @@ from ..ir import ( SfgVisibilityBlock, SfgEntityDecl, SfgEntityDef, + SfgClassBody, ) from ..exceptions import SfgException @@ -71,7 +72,7 @@ class SfgClassComposer(SfgComposerMixIn): self._args = args return self - def _resolve(self, ctx: SfgContext, cls: SfgClass): + def _resolve(self, ctx: SfgContext, cls: SfgClass) -> SfgVisibilityBlock: vis_block = SfgVisibilityBlock(self._visibility) for arg in self._args: match arg: @@ -86,7 +87,8 @@ class SfgClassComposer(SfgComposerMixIn): var = asvar(arg) member_var = SfgMemberVariable(var.name, var.dtype, cls) cls.add_member(member_var, vis_block.visibility) - vis_block.elements.append(member_var) + vis_block.elements.append(SfgEntityDef(member_var)) + return vis_block class MethodSequencer: def __init__( @@ -133,7 +135,7 @@ class SfgClassComposer(SfgComposerMixIn): def __init__(self, *params: VarLike): self._params = list(asvar(p) for p in params) - self._initializers: list[str] = [] + self._initializers: list[tuple[SfgVar | str, tuple[ExprLike, ...]]] = [] self._body: str | None = None def add_param(self, param: VarLike, at: int | None = None): @@ -152,9 +154,7 @@ class SfgClassComposer(SfgComposerMixIn): member = var if isinstance(var, str) else asvar(var) def init_sequencer(*args: ExprLike): - expr = ", ".join(str(arg) for arg in args) - initializer = f"{member}{{ {expr} }}" - self._initializers.append(initializer) + self._initializers.append((member, args)) return self return init_sequencer @@ -287,18 +287,19 @@ class SfgClassComposer(SfgComposerMixIn): argfilter, args, ) - default_vis_sequencer(*default_vis_args)._resolve(self._ctx, cls) # type: ignore + default_block = default_vis_sequencer(*default_vis_args)._resolve(self._ctx, cls) # type: ignore + vis_blocks: list[SfgVisibilityBlock] = [] for arg in dropwhile(argfilter, args): if isinstance(arg, SfgClassComposer.VisibilityBlockSequencer): - arg._resolve(self._ctx, cls) + vis_blocks.append(arg._resolve(self._ctx, cls)) else: raise SfgException( "Composer Syntax Error: " "Cannot add members with default visibility after a visibility block." ) - self._cursor.write_header(SfgEntityDef(cls)) + self._cursor.write_header(SfgClassBody(cls, default_block, vis_blocks)) return sequencer diff --git a/src/pystencilssfg/config.py b/src/pystencilssfg/config.py index aae9dab541f95c2d7af09da46bce1346150bb4a3..7bbcfc60fb8534e49b244c2ce06f4c00d09165d1 100644 --- a/src/pystencilssfg/config.py +++ b/src/pystencilssfg/config.py @@ -8,9 +8,9 @@ from dataclasses import dataclass from enum import Enum, auto from os import path from importlib import util as iutil +from pathlib import Path - -from pystencils.codegen.config import ConfigBase, BasicOption, Category +from pystencils.codegen.config import ConfigBase, Option, BasicOption, Category class SfgConfigException(Exception): ... # noqa: E701 @@ -166,9 +166,13 @@ class SfgConfig(ConfigBase): ClangFormatOptions.binary """ - output_directory: BasicOption[str] = BasicOption(".") + output_directory: Option[Path, str | Path] = Option(Path(".")) """Directory to which the generated files should be written.""" + @output_directory.validate + def _validate_output_directory(self, pth: str | Path) -> Path: + return Path(pth) + class CommandLineParameters: @staticmethod diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index a129f982e36344410d5fdd58ae5501e112210bc0..f1a32517f3c775435caf135fcaa0e13fa242ad80 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -19,7 +19,7 @@ class SfgContext: def __init__( self, header_file: SfgSourceFile, - impl_file: SfgSourceFile, + impl_file: SfgSourceFile | None, outer_namespace: str | None = None, codestyle: CodeStyle | None = None, argv: Sequence[str] | None = None, diff --git a/src/pystencilssfg/emission/__init__.py b/src/pystencilssfg/emission/__init__.py index fd666283edd12d49c02eb42eee90a5ae8f9756dd..1a22aa26a0bf79acec636488ec6c5e9c00834b78 100644 --- a/src/pystencilssfg/emission/__init__.py +++ b/src/pystencilssfg/emission/__init__.py @@ -1,5 +1,4 @@ -from .emitter import AbstractEmitter, OutputSpec -from .header_impl_pair import HeaderImplPairEmitter -from .header_only import HeaderOnlyEmitter +from .emitter import SfgCodeEmitter +from .file_printer import SfgFilePrinter -__all__ = ["AbstractEmitter", "OutputSpec", "HeaderImplPairEmitter", "HeaderOnlyEmitter"] +__all__ = ["SfgCodeEmitter", "SfgFilePrinter"] diff --git a/src/pystencilssfg/emission/emitter.py b/src/pystencilssfg/emission/emitter.py index c32b18af351579b98477bf688c0789a394fc3732..440344b0d188fb5a2a3b987b9fcbbf029df3a9d5 100644 --- a/src/pystencilssfg/emission/emitter.py +++ b/src/pystencilssfg/emission/emitter.py @@ -1,71 +1,29 @@ from __future__ import annotations -from typing import Sequence -from abc import ABC, abstractmethod -from dataclasses import dataclass -from os import path +from pathlib import Path -from ..context import SfgContext -from ..config import SfgConfig, OutputMode +from ..config import CodeStyle, ClangFormatOptions +from ..ir import SfgSourceFile +from .file_printer import SfgFilePrinter +from .clang_format import invoke_clang_format -@dataclass -class OutputSpec: - """Name and path specification for files output by the code generator. - Filenames are constructed as `<output_directory>/<basename>.<extension>`.""" +class SfgCodeEmitter: + def __init__( + self, + output_directory: Path, + code_style: CodeStyle, + clang_format: ClangFormatOptions, + ): + self._output_dir = output_directory + self._clang_format_opts = clang_format + self._printer = SfgFilePrinter(code_style) - output_directory: str - """Directory to which the generated files should be written.""" + def emit(self, file: SfgSourceFile): + code = self._printer(file) + code = invoke_clang_format(code, self._clang_format_opts) - basename: str - """Base name for output files.""" - - header_extension: str - """File extension for generated header file.""" - - impl_extension: str - """File extension for generated implementation file.""" - - def get_header_filename(self): - return f"{self.basename}.{self.header_extension}" - - def get_impl_filename(self): - return f"{self.basename}.{self.impl_extension}" - - def get_header_filepath(self): - return path.join(self.output_directory, self.get_header_filename()) - - def get_impl_filepath(self): - return path.join(self.output_directory, self.get_impl_filename()) - - @staticmethod - def create(config: SfgConfig, basename: str) -> OutputSpec: - output_mode = config.get_option("output_mode") - header_extension = config.extensions.get_option("header") - impl_extension = config.extensions.get_option("impl") - - if impl_extension is None: - match output_mode: - case OutputMode.INLINE: - impl_extension = "ipp" - case OutputMode.STANDALONE: - impl_extension = "cpp" - - return OutputSpec( - config.get_option("output_directory"), - basename, - header_extension, - impl_extension, - ) - - -class AbstractEmitter(ABC): - @property - @abstractmethod - def output_files(self) -> Sequence[str]: - pass - - @abstractmethod - def write_files(self, ctx: SfgContext): - pass + self._output_dir.mkdir(parents=True, exist_ok=True) + fpath = self._output_dir / file.name + fpath.write_text(code) diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py index 2dc1e7b4891c9dcfefff8dcfc7e6f68966d379bc..ec434689903fe4599bb86da82765e6c1dfe68768 100644 --- a/src/pystencilssfg/emission/file_printer.py +++ b/src/pystencilssfg/emission/file_printer.py @@ -1,20 +1,35 @@ from __future__ import annotations from textwrap import indent +from pystencils.backend.emission import CAstPrinter + from ..ir import ( SfgSourceFile, SfgSourceFileType, SfgNamespaceBlock, SfgEntityDecl, SfgEntityDef, + SfgKernelHandle, + SfgFunction, + SfgClassMember, + SfgMethod, + SfgMemberVariable, + SfgConstructor, + SfgClass, + SfgClassBody, + SfgVisibilityBlock, + SfgVisibility, ) -from ..ir.syntax import SfgNamespaceElement +from ..ir.syntax import SfgNamespaceElement, SfgClassBodyElement from ..config import CodeStyle class SfgFilePrinter: def __init__(self, code_style: CodeStyle) -> None: self._code_style = code_style + self._kernel_printer = CAstPrinter( + indent_width=code_style.get_option("indent_width") + ) def __call__(self, file: SfgSourceFile) -> str: code = "" @@ -41,17 +56,132 @@ class SfgFilePrinter: code += "\n" return code - def visit(self, elem: SfgNamespaceElement) -> str: + def visit( + self, elem: SfgNamespaceElement | SfgClassBodyElement, inclass: bool = False + ) -> str: match elem: case str(): return elem - case SfgNamespaceBlock(name, elements): - code = f"namespace {name} {{\n" + case SfgNamespaceBlock(namespace, elements): + code = f"namespace {namespace.name} {{\n" code += self._code_style.indent( "\n\n".join(self.visit(e) for e in elements) ) - code += f"\n}} // namespace {name}" + code += f"\n}} // namespace {namespace.name}" + return code case SfgEntityDecl(entity): - code += self.visit_decl(entity) + return self.visit_decl(entity, inclass) case SfgEntityDef(entity): - code += self.visit_defin(entity) + return self.visit_defin(entity, inclass) + case _: + assert False, "illegal code element" + + def visit_decl( + self, + declared_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass, + inclass: bool = False, + ) -> str: + match declared_entity: + case SfgKernelHandle(kernel): + return self._kernel_printer.print_signature(kernel) + ";" + + case SfgFunction(name, _, params) | SfgMethod(name, _, params): + return self._func_signature(declared_entity, inclass) + ";" + + case SfgConstructor(cls, params): + params_str = ", ".join( + f"{param.dtype.c_string()} {param.name}" for param in params + ) + return f"{cls.name}({params_str});" + + case SfgMemberVariable(name, dtype): + return f"{dtype.c_string()} {name};" + + case SfgClass(kwd, name): + return f"{str(kwd)} {name};" + + case _: + assert False, f"unsupported declared entity: {declared_entity}" + + def visit_defin( + self, + defined_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClassBody, + inclass: bool = False, + ) -> str: + match defined_entity: + case SfgKernelHandle(kernel): + return self._kernel_printer(kernel) + + case SfgFunction(name, tree, params) | SfgMethod(name, tree, params): + sig = self._func_signature(defined_entity, inclass) + body = tree.get_code(self._code_style) + body = "\n{\n" + self._code_style.indent(body) + "\n}" + return sig + body + + case SfgConstructor(cls, params): + params_str = ", ".join( + f"{param.dtype.c_string()} {param.name}" for param in params + ) + + code = "" + if not inclass: + code += f"{cls.name}::" + code += f"{cls.name} ({params_str})" + + inits: list[str] = [] + for var, args in defined_entity.initializers: + args_str = ", ".join(str(arg) for arg in args) + inits.append(f"{str(var)}({args_str})") + + if inits: + code += "\n:" + ",\n".join(inits) + + code += "\n{\n" + self._code_style.indent(defined_entity.body) + "\n}" + return code + + case SfgMemberVariable(name, dtype): + code = dtype.c_string() + if not inclass: + code += f" {defined_entity.owning_class.name}::" + code += f" {name}" + if defined_entity.default_init is not None: + args_str = ", ".join(str(expr) for expr in defined_entity.default_init) + code += "{" + args_str + "}" + code += ";" + return code + + case SfgClassBody(cls, vblocks): + code = f"{cls.class_keyword} {cls.name} {{\n" + vblocks_str = [self._visibility_block(b) for b in vblocks] + code += "\n\n".join(vblocks_str) + code += "\n}\n" + return code + + case _: + assert False, f"unsupported defined entity: {defined_entity}" + + def _visibility_block(self, vblock: SfgVisibilityBlock): + prefix = ( + f"{vblock.visibility}:\n" + if vblock.visibility != SfgVisibility.DEFAULT + else "" + ) + elements = [self.visit(elem, inclass=True) for elem in vblock.elements] + return prefix + self._code_style.indent("\n".join(elements)) + + def _func_signature(self, func: SfgFunction | SfgMethod, inclass: bool): + code = "" + if func.inline: + code += "inline " + code += func.return_type.c_string() + params_str = ", ".join( + f"{param.dtype.c_string()} {param.name}" for param in func.parameters + ) + if isinstance(func, SfgMethod) and not inclass: + code += f"{func.owning_class.name}::" + code += f"{func.name}({params_str})" + + if isinstance(func, SfgMethod) and func.const: + code += " const" + + return code diff --git a/src/pystencilssfg/emission/header_impl_pair.py b/src/pystencilssfg/emission/header_impl_pair.py deleted file mode 100644 index 87ff5f55c1484ffd8e5afd1eb53c0349c7c961e8..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/emission/header_impl_pair.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Sequence -from os import path, makedirs - -from ..context import SfgContext -from .printers import SfgHeaderPrinter, SfgImplPrinter -from .clang_format import invoke_clang_format -from ..config import ClangFormatOptions - -from .emitter import AbstractEmitter, OutputSpec - - -class HeaderImplPairEmitter(AbstractEmitter): - """Emits a header-implementation file pair.""" - - def __init__( - self, - output_spec: OutputSpec, - inline_impl: bool = False, - clang_format: ClangFormatOptions | None = None, - ): - """Create a `HeaderImplPairEmitter` from an [SfgOutputSpec][pystencilssfg.configuration.SfgOutputSpec].""" - self._basename = output_spec.basename - self._output_directory = output_spec.output_directory - self._header_filename = output_spec.get_header_filename() - self._impl_filename = output_spec.get_impl_filename() - self._inline_impl = inline_impl - - self._ospec = output_spec - self._clang_format = clang_format - - @property - def output_files(self) -> Sequence[str]: - """The files that will be written by `write_files`.""" - return ( - path.join(self._output_directory, self._header_filename), - path.join(self._output_directory, self._impl_filename), - ) - - def write_files(self, ctx: SfgContext): - """Write the code represented by the given [SfgContext][pystencilssfg.SfgContext] to the files - specified by the output specification.""" - header_printer = SfgHeaderPrinter(ctx, self._ospec, self._inline_impl) - impl_printer = SfgImplPrinter(ctx, self._ospec, self._inline_impl) - - header = header_printer.get_code() - impl = impl_printer.get_code() - - if self._clang_format is not None: - header = invoke_clang_format(header, self._clang_format) - impl = invoke_clang_format(impl, self._clang_format) - - makedirs(self._output_directory, exist_ok=True) - - with open(self._ospec.get_header_filepath(), "w") as headerfile: - headerfile.write(header) - - with open(self._ospec.get_impl_filepath(), "w") as cppfile: - cppfile.write(impl) diff --git a/src/pystencilssfg/emission/header_only.py b/src/pystencilssfg/emission/header_only.py deleted file mode 100644 index 7d026da7aea8b32495bba6c2298adfb35ff4fb00..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/emission/header_only.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Sequence -from os import path, makedirs - -from ..context import SfgContext -from .printers import SfgHeaderPrinter -from ..config import ClangFormatOptions -from .clang_format import invoke_clang_format - -from .emitter import AbstractEmitter, OutputSpec - - -class HeaderOnlyEmitter(AbstractEmitter): - def __init__( - self, output_spec: OutputSpec, clang_format: ClangFormatOptions | None = None - ): - """Create a `HeaderImplPairEmitter` from an [SfgOutputSpec][pystencilssfg.configuration.SfgOutputSpec].""" - self._basename = output_spec.basename - self._output_directory = output_spec.output_directory - self._header_filename = output_spec.get_header_filename() - - self._ospec = output_spec - self._clang_format = clang_format - - @property - def output_files(self) -> Sequence[str]: - """The files that will be written by `write_files`.""" - return (path.join(self._output_directory, self._header_filename),) - - def write_files(self, ctx: SfgContext): - header_printer = SfgHeaderPrinter(ctx, self._ospec) - header = header_printer.get_code() - if self._clang_format is not None: - header = invoke_clang_format(header, self._clang_format) - - makedirs(self._output_directory, exist_ok=True) - - with open(self._ospec.get_header_filepath(), "w") as headerfile: - headerfile.write(header) diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py deleted file mode 100644 index 9d7c97e7ce732066c91eda3e3cbf887dcb552f77..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/emission/printers.py +++ /dev/null @@ -1,265 +0,0 @@ -from __future__ import annotations - -from textwrap import indent -from itertools import chain, repeat, cycle - -from pystencils.codegen import Kernel -from pystencils.backend.emission import emit_code - -from ..context import SfgContext -from ..visitors import visitor -from ..exceptions import SfgException - -from ..ir.source_components import ( - SfgEmptyLines, - SfgHeaderInclude, - SfgKernelNamespace, - SfgFunction, - SfgClass, - SfgInClassDefinition, - SfgConstructor, - SfgMemberVariable, - SfgMethod, - SfgVisibility, - SfgVisibilityBlock, -) - -from .emitter import OutputSpec - - -def interleave(*iters): - try: - for iter in cycle(iters): - yield next(iter) - except StopIteration: - pass - - -class SfgGeneralPrinter: - @visitor - def visit(self, obj: object) -> str: - raise SfgException(f"Can't print object of type {type(obj)}") - - @visit.case(SfgEmptyLines) - def emptylines(self, el: SfgEmptyLines) -> str: - return "\n" * el.lines - - @visit.case(str) - def string(self, s: str) -> str: - return s - - @visit.case(SfgHeaderInclude) - def include(self, incl: SfgHeaderInclude) -> str: - if incl.system_header: - return f"#include <{incl.file}>" - else: - return f'#include "{incl.file}"' - - def prelude(self, ctx: SfgContext) -> str: - if ctx.prelude_comment: - return ( - "/*\n" - + indent(ctx.prelude_comment, "* ", predicate=lambda _: True) - + "*/\n" - ) - else: - return "" - - def param_list(self, func: SfgFunction) -> str: - params = sorted(list(func.parameters), key=lambda p: p.name) - return ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) - - -class SfgHeaderPrinter(SfgGeneralPrinter): - def __init__( - self, ctx: SfgContext, output_spec: OutputSpec, inline_impl: bool = False - ): - self._output_spec = output_spec - self._ctx = ctx - self._inline_impl = inline_impl - - def get_code(self) -> str: - return self.visit(self._ctx) - - @visitor - def visit(self, obj: object) -> str: - return super().visit(obj) - - @visit.case(SfgContext) - def frame(self, ctx: SfgContext) -> str: - code = super().prelude(ctx) - - code += "\n#pragma once\n\n" - - includes = filter(lambda incl: not incl.private, ctx.includes()) - code += "\n".join(self.visit(incl) for incl in includes) - code += "\n\n" - - fq_namespace = ctx.fully_qualified_namespace - if fq_namespace is not None: - code += f"namespace {fq_namespace} {{\n\n" - - parts = interleave(ctx.declarations_ordered(), repeat(SfgEmptyLines(1))) - - code += "\n".join(self.visit(p) for p in parts) - - if fq_namespace is not None: - code += f"}} // namespace {fq_namespace}\n" - - if self._inline_impl: - code += f'#include "{self._output_spec.get_impl_filename()}"\n' - - return code - - @visit.case(SfgFunction) - def function(self, func: SfgFunction): - params = sorted(list(func.parameters), key=lambda p: p.name) - param_list = ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) - return f"{func.return_type.c_string()} {func.name} ( {param_list} );" - - @visit.case(SfgClass) - def sfg_class(self, cls: SfgClass): - code = f"{cls.class_keyword} {cls.class_name} \n" - - if cls.base_classes: - code += f" : {','.join(cls.base_classes)}\n" - - code += "{\n" - - for block in cls.visibility_blocks(): - code += self.visit(block) + "\n" - - code += "};\n" - - return code - - @visit.case(SfgVisibilityBlock) - def vis_block(self, block: SfgVisibilityBlock) -> str: - code = "" - if block.visibility != SfgVisibility.DEFAULT: - code += f"{block.visibility}:\n" - code += self._ctx.codestyle.indent( - "\n".join(self.visit(m) for m in block.members()) - ) - return code - - @visit.case(SfgInClassDefinition) - def sfg_inclassdef(self, definition: SfgInClassDefinition): - return definition.text - - @visit.case(SfgConstructor) - def sfg_constructor(self, constr: SfgConstructor): - code = f"{constr.owning_class.class_name} (" - code += ", ".join(f"{param.dtype.c_string()} {param.name}" for param in constr.parameters) - code += ")\n" - if constr.initializers: - code += " : " + ", ".join(constr.initializers) + "\n" - if constr.body: - code += "{\n" + self._ctx.codestyle.indent(constr.body) + "\n}\n" - else: - code += "{ }\n" - return code - - @visit.case(SfgMemberVariable) - def sfg_member_var(self, var: SfgMemberVariable): - return f"{var.dtype.c_string()} {var.name};" - - @visit.case(SfgMethod) - def sfg_method(self, method: SfgMethod): - code = f"{method.return_type.c_string()} {method.name} ({self.param_list(method)})" - code += "const" if method.const else "" - if method.inline: - code += ( - " {\n" - + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) - + "}\n" - ) - else: - code += ";" - return code - - -class SfgImplPrinter(SfgGeneralPrinter): - def __init__( - self, ctx: SfgContext, output_spec: OutputSpec, inline_impl: bool = False - ): - self._output_spec = output_spec - self._ctx = ctx - self._inline_impl = inline_impl - - def get_code(self) -> str: - return self.visit(self._ctx) - - @visitor - def visit(self, obj: object) -> str: - return super().visit(obj) - - @visit.case(SfgContext) - def frame(self, ctx: SfgContext) -> str: - code = super().prelude(ctx) - - if not self._inline_impl: - code += f'\n#include "{self._output_spec.get_header_filename()}"\n\n' - - includes = filter(lambda incl: incl.private, ctx.includes()) - code += "\n".join(self.visit(incl) for incl in includes) - - code += "\n\n#define FUNC_PREFIX inline\n\n" - - fq_namespace = ctx.fully_qualified_namespace - if fq_namespace is not None: - code += f"namespace {fq_namespace} {{\n\n" - - parts = interleave( - chain( - ctx.kernel_namespaces(), - ctx.functions(), - ctx.classes(), - ), - repeat(SfgEmptyLines(1)), - ) - - code += "\n".join(self.visit(p) for p in parts) - - if fq_namespace is not None: - code += f"}} // namespace {fq_namespace}\n" - - return code - - @visit.case(SfgKernelNamespace) - def kernel_namespace(self, kns: SfgKernelNamespace) -> str: - code = f"namespace {kns.name} {{\n\n" - code += "\n\n".join(self.visit(ast) for ast in kns.kernel_functions) - code += f"\n}} // namespace {kns.name}\n" - return code - - @visit.case(Kernel) - def kernel(self, kfunc: Kernel) -> str: - return emit_code(kfunc) - - @visit.case(SfgFunction) - def function(self, func: SfgFunction) -> str: - inline_prefix = "inline " if self._inline_impl else "" - code = ( - f"{inline_prefix} {func.return_type.c_string()} {func.name} ({self.param_list(func)})" - ) - code += ( - "{\n" + self._ctx.codestyle.indent(func.tree.get_code(self._ctx)) + "}\n" - ) - return code - - @visit.case(SfgClass) - def sfg_class(self, cls: SfgClass) -> str: - methods = filter(lambda m: not m.inline, cls.methods()) - return "\n".join(self.visit(m) for m in methods) - - @visit.case(SfgMethod) - def sfg_method(self, method: SfgMethod) -> str: - inline_prefix = "inline " if self._inline_impl else "" - const_qual = "const" if method.const else "" - code = f"{inline_prefix}{method.return_type} {method.owning_class.class_name}::{method.name}" - code += f"({self.param_list(method)}) {const_qual}" - code += ( - " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n" - ) - return code diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index a628f00802b7863eb2f67a9a84598e4aa1be4e01..48f9c08d8754d9a6626109b1c037983a923b33f2 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -8,6 +8,7 @@ from pystencils import Target from pystencilssfg.composer.basic_composer import SequencerArg +from ..config import CodeStyle from ..exceptions import SfgException from ..context import SfgContext from ..composer import ( @@ -93,11 +94,11 @@ class SyclHandler(AugExpr): if isinstance(range, _VarLike): range = asvar(range) - def check_kernel(kernel: SfgKernelHandle): - kfunc = kernel.get_kernel_function() + def check_kernel(khandle: SfgKernelHandle): + kfunc = khandle.kernel if kfunc.target != Target.SYCL: raise SfgException( - f"Kernel given to `parallel_for` is no SYCL kernel: {kernel.kernel_name}" + f"Kernel given to `parallel_for` is no SYCL kernel: {khandle.fqname}" ) id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>") @@ -144,7 +145,7 @@ class SyclGroup(AugExpr): self._ctx = ctx def parallel_for_work_item( - self, range: VarLike | Sequence[int], kernel: SfgKernelHandle + self, range: VarLike | Sequence[int], khandle: SfgKernelHandle ): """Generate a ``parallel_for_work_item` kernel invocation on this group.` @@ -155,10 +156,10 @@ class SyclGroup(AugExpr): if isinstance(range, _VarLike): range = asvar(range) - kfunc = kernel.get_kernel_function() + kfunc = khandle.kernel if kfunc.target != Target.SYCL: raise SfgException( - f"Kernel given to `parallel_for` is no SYCL kernel: {kernel.kernel_name}" + f"Kernel given to `parallel_for` is no SYCL kernel: {khandle.fqname}" ) id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>") @@ -169,13 +170,13 @@ class SyclGroup(AugExpr): and id_regex.search(param.dtype.c_string()) is not None ) - id_param = list(filter(filter_id, kernel.scalar_parameters))[0] + id_param = list(filter(filter_id, khandle.scalar_parameters))[0] h_item = SfgVar("item", PsCustomType("sycl::h_item< 3 >")) comp = SfgComposer(self._ctx) tree = comp.seq( comp.set_param(id_param, AugExpr.format("{}.get_local_id()", h_item)), - SfgKernelCallNode(kernel), + SfgKernelCallNode(khandle), ) kernel_lambda = SfgLambda(("=",), (h_item,), tree, None) @@ -229,11 +230,11 @@ class SfgLambda: def required_parameters(self) -> set[SfgVar]: return self._required_params - def get_code(self, ctx: SfgContext): + def get_code(self, cstyle: CodeStyle): captures = ", ".join(self._captures) params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params) - body = self._tree.get_code(ctx) - body = ctx.codestyle.indent(body) + body = self._tree.get_code(cstyle) + body = cstyle.indent(body) rtype = ( f"-> {self._return_type.c_string()} " if self._return_type is not None @@ -300,13 +301,13 @@ class SyclKernelInvoke(SfgCallTreeLeaf): def depends(self) -> set[SfgVar]: return self._required_params - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: if isinstance(self._range, SfgVar): range_code = self._range.name else: range_code = "{ " + ", ".join(str(r) for r in self._range) + " }" - kernel_code = self._lambda.get_code(ctx) + kernel_code = self._lambda.get_code(cstyle) invoker = str(self._invoker) method = self._invoke_type.method diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index eed06b0d684850e462b86aa11b53bc4ca1586185..b055ad518d6f198d5f8aabd8026353177ae3c70a 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -1,10 +1,9 @@ -import os -from os import path +from pathlib import Path from .config import SfgConfig, CommandLineParameters, OutputMode, GLOBAL_NAMESPACE from .context import SfgContext from .composer import SfgComposer -from .emission import AbstractEmitter, OutputSpec +from .emission import SfgCodeEmitter from .exceptions import SfgException @@ -41,9 +40,9 @@ class SourceFileGenerator: "without a valid entry point, such as a REPL or a multiprocessing fork." ) - scriptpath = __main__.__file__ - scriptname = path.split(scriptpath)[1] - basename = path.splitext(scriptname)[0] + scriptpath = Path(__main__.__file__) + scriptname = scriptpath.name + basename = scriptname.rsplit(".")[0] from argparse import ArgumentParser @@ -67,47 +66,52 @@ class SourceFileGenerator: cli_params.find_conflicts(sfg_config) config.override(sfg_config) - self._context = SfgContext( - None if config.outer_namespace is GLOBAL_NAMESPACE else config.outer_namespace, # type: ignore - config.codestyle, - argv=script_args, - project_info=cli_params.get_project_info(), - ) + self._output_mode: OutputMode = config.get_option("output_mode") + self._output_dir: Path = config.get_option("output_directory") + self._header_ext: str = config.extensions.get_option("header") + self._impl_ext: str = config.extensions.get_option("impl") - from .lang import HeaderFile - from .ir import SfgHeaderInclude + from .ir import SfgSourceFile, SfgSourceFileType - self._context.add_include(SfgHeaderInclude(HeaderFile("cstdint", system_header=True))) - self._context.add_definition("#define RESTRICT __restrict__") - - output_mode = config.get_option("output_mode") - output_spec = OutputSpec.create(config, basename) + self._header_file = SfgSourceFile( + f"{basename}.{self._header_ext}", SfgSourceFileType.HEADER + ) + self._impl_file: SfgSourceFile | None - self._emitter: AbstractEmitter - match output_mode: + match self._output_mode: case OutputMode.HEADER_ONLY: - from .emission import HeaderOnlyEmitter - - self._emitter = HeaderOnlyEmitter( - output_spec, clang_format=config.clang_format + self._impl_file = None + case OutputMode.STANDALONE: + self._impl_file = SfgSourceFile( + f"{basename}.{self._impl_ext}", SfgSourceFileType.TRANSLATION_UNIT ) case OutputMode.INLINE: - from .emission import HeaderImplPairEmitter - - self._emitter = HeaderImplPairEmitter( - output_spec, inline_impl=True, clang_format=config.clang_format + self._impl_file = SfgSourceFile( + f"{basename}.{self._impl_ext}", SfgSourceFileType.HEADER ) - case OutputMode.STANDALONE: - from .emission import HeaderImplPairEmitter - self._emitter = HeaderImplPairEmitter( - output_spec, clang_format=config.clang_format - ) + self._context = SfgContext( + self._header_file, + self._impl_file, + None if config.outer_namespace is GLOBAL_NAMESPACE else config.outer_namespace, # type: ignore + config.codestyle, + argv=script_args, + project_info=cli_params.get_project_info(), + ) + + self._emitter = SfgCodeEmitter( + self._output_dir, config.codestyle, config.clang_format + ) def clean_files(self): - for file in self._emitter.output_files: - if path.exists(file): - os.remove(file) + header_path = self._output_dir / self._header_file.name + if header_path.exists(): + header_path.unlink() + + if self._impl_file is not None: + impl_path = self._output_dir / self._impl_file.name + if impl_path.exists(): + impl_path.unlink() def __enter__(self) -> SfgComposer: self.clean_files() @@ -116,8 +120,12 @@ class SourceFileGenerator: def __exit__(self, exc_type, exc_value, traceback): if exc_type is None: # Collect header files for inclusion - from .ir import SfgHeaderInclude, collect_includes - for header in collect_includes(self._context): - self._context.add_include(SfgHeaderInclude(header)) + # from .ir import collect_includes + + # TODO: Collect headers + # for header in collect_includes(self._context): + # self._context.add_include(SfgHeaderInclude(header)) - self._emitter.write_files(self._context) + self._emitter.emit(self._header_file) + if self._impl_file is not None: + self._emitter.emit(self._impl_file) diff --git a/src/pystencilssfg/ir/analysis.py b/src/pystencilssfg/ir/analysis.py index c2c3e34675585568115d1a4a06f0867eb9ca563c..b88c4f3fe3e2265c281734ec8725536cb609c130 100644 --- a/src/pystencilssfg/ir/analysis.py +++ b/src/pystencilssfg/ir/analysis.py @@ -8,7 +8,6 @@ from ..lang import HeaderFile, includes def collect_includes(obj: Any) -> set[HeaderFile]: - from ..context import SfgContext from .call_tree import SfgCallTreeNode from .entities import ( SfgFunction, @@ -18,15 +17,7 @@ def collect_includes(obj: Any) -> set[HeaderFile]: ) match obj: - case SfgContext(): - headers = set() - for func in obj.functions(): - headers |= collect_includes(func) - - for cls in obj.classes(): - headers |= collect_includes(cls) - - return headers + # TODO case SfgCallTreeNode(): return reduce( @@ -48,7 +39,7 @@ def collect_includes(obj: Any) -> set[HeaderFile]: set(), ) - case SfgConstructor(parameters): + case SfgConstructor(_, parameters): param_headers = reduce( set.union, (includes(p) for p in parameters), set() ) diff --git a/src/pystencilssfg/ir/call_tree.py b/src/pystencilssfg/ir/call_tree.py index 2e057dd1b2864bfaf2845501cccf3ff10673640c..4cee2f526c3a26e34f6b122a3dd1d8a15dd11563 100644 --- a/src/pystencilssfg/ir/call_tree.py +++ b/src/pystencilssfg/ir/call_tree.py @@ -7,7 +7,7 @@ from .entities import SfgKernelHandle from ..lang import SfgVar, HeaderFile if TYPE_CHECKING: - from ..context import SfgContext + from ..config import CodeStyle class SfgCallTreeNode(ABC): @@ -35,7 +35,7 @@ class SfgCallTreeNode(ABC): """This node's children""" @abstractmethod - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: """Returns the code of this node. By convention, the code block emitted by this function should not contain a trailing newline. @@ -75,7 +75,7 @@ class SfgEmptyNode(SfgCallTreeLeaf): def __init__(self): super().__init__() - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: return "" @@ -122,7 +122,7 @@ class SfgStatements(SfgCallTreeLeaf): def code_string(self) -> str: return self._code_string - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: return self._code_string @@ -167,8 +167,8 @@ class SfgSequence(SfgCallTreeNode): def __setitem__(self, idx: int, c: SfgCallTreeNode): self._children[idx] = c - def get_code(self, ctx: SfgContext) -> str: - return "\n".join(c.get_code(ctx) for c in self._children) + def get_code(self, cstyle: CodeStyle) -> str: + return "\n".join(c.get_code(cstyle) for c in self._children) class SfgBlock(SfgCallTreeNode): @@ -184,8 +184,8 @@ class SfgBlock(SfgCallTreeNode): def children(self) -> Sequence[SfgCallTreeNode]: return (self._seq,) - def get_code(self, ctx: SfgContext) -> str: - seq_code = ctx.codestyle.indent(self._seq.get_code(ctx)) + def get_code(self, cstyle: CodeStyle) -> str: + seq_code = cstyle.indent(self._seq.get_code(cstyle)) return "{\n" + seq_code + "\n}" @@ -208,7 +208,7 @@ class SfgKernelCallNode(SfgCallTreeLeaf): def depends(self) -> set[SfgVar]: return set(self._kernel_handle.parameters) - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: ast_params = self._kernel_handle.parameters fnc_name = self._kernel_handle.fqname call_parameters = ", ".join([p.name for p in ast_params]) @@ -228,8 +228,8 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf): from pystencils import Target from pystencils.codegen import GpuKernel - func = kernel_handle.get_kernel() - if not (isinstance(func, GpuKernel) and func.target == Target.CUDA): + kernel = kernel_handle.kernel + if not (isinstance(kernel, GpuKernel) and kernel.target == Target.CUDA): raise ValueError( "An `SfgCudaKernelInvocation` node can only call a CUDA kernel." ) @@ -245,7 +245,7 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf): def depends(self) -> set[SfgVar]: return set(self._kernel_handle.parameters) | self._depends - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: ast_params = self._kernel_handle.parameters fnc_name = self._kernel_handle.fqname call_parameters = ", ".join([p.name for p in ast_params]) @@ -289,14 +289,14 @@ class SfgBranch(SfgCallTreeNode): self._branch_true, ) + ((self.branch_false,) if self.branch_false is not None else ()) - def get_code(self, ctx: SfgContext) -> str: - code = f"if({self.condition.get_code(ctx)}) {{\n" - code += ctx.codestyle.indent(self.branch_true.get_code(ctx)) + def get_code(self, cstyle: CodeStyle) -> str: + code = f"if({self.condition.get_code(cstyle)}) {{\n" + code += cstyle.indent(self.branch_true.get_code(cstyle)) code += "\n}" if self.branch_false is not None: code += "else {\n" - code += ctx.codestyle.indent(self.branch_false.get_code(ctx)) + code += cstyle.indent(self.branch_false.get_code(cstyle)) code += "\n}" return code @@ -327,13 +327,13 @@ class SfgSwitchCase(SfgCallTreeNode): def is_default(self) -> bool: return self._label == SfgSwitchCase.Default - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: code = "" if self._label == SfgSwitchCase.Default: code += "default: {\n" else: code += f"case {self._label}: {{\n" - code += ctx.codestyle.indent(self.body.get_code(ctx)) + code += cstyle.indent(self.body.get_code(cstyle)) code += "\n}" return code @@ -403,8 +403,8 @@ class SfgSwitch(SfgCallTreeNode): else: self._children[idx] = c - def get_code(self, ctx: SfgContext) -> str: - code = f"switch({self._switch_arg.get_code(ctx)}) {{\n" - code += "\n".join(c.get_code(ctx) for c in self._cases) + def get_code(self, cstyle: CodeStyle) -> str: + code = f"switch({self._switch_arg.get_code(cstyle)}) {{\n" + code += "\n".join(c.get_code(cstyle) for c in self._cases) code += "}" return code diff --git a/src/pystencilssfg/ir/entities.py b/src/pystencilssfg/ir/entities.py index 9d76b39fbf68afe34f62f150aa32b4557be0a840..2b0e3e5c3e5fd2f9db9f49922fed3bdedad331ef 100644 --- a/src/pystencilssfg/ir/entities.py +++ b/src/pystencilssfg/ir/entities.py @@ -14,7 +14,7 @@ from pystencils import CreateKernelConfig, create_kernel, Field from pystencils.codegen import Kernel from pystencils.types import PsType, PsCustomType -from ..lang import SfgVar, SfgKernelParamVar, void +from ..lang import SfgVar, SfgKernelParamVar, void, ExprLike from ..exceptions import SfgException if TYPE_CHECKING: @@ -100,6 +100,8 @@ class SfgGlobalNamespace(SfgNamespace): class SfgKernelHandle(SfgCodeEntity): """Handle to a pystencils kernel.""" + __match_args__ = ("kernel",) + def __init__(self, name: str, namespace: SfgKernelNamespace, kernel: Kernel): super().__init__(name, namespace) @@ -130,7 +132,8 @@ class SfgKernelHandle(SfgCodeEntity): """Fields accessed by this kernel""" return self._fields - def get_kernel(self) -> Kernel: + @property + def kernel(self) -> Kernel: """Underlying pystencils kernel object""" return self._kernel @@ -315,9 +318,20 @@ class SfgClassMember(ABC): class SfgMemberVariable(SfgVar, SfgClassMember): """Variable that is a field of a class""" - def __init__(self, name: str, dtype: PsType, cls: SfgClass): + def __init__( + self, + name: str, + dtype: PsType, + cls: SfgClass, + default_init: tuple[ExprLike, ...] | None = None, + ): SfgVar.__init__(self, name, dtype) SfgClassMember.__init__(self, cls) + self._default_init = default_init + + @property + def default_init(self) -> tuple[ExprLike, ...] | None: + return self._default_init class SfgMethod(SfgClassMember): @@ -349,6 +363,10 @@ class SfgMethod(SfgClassMember): param_collector = CallTreePostProcessing() self._parameters = param_collector(self._tree).function_params + @property + def name(self) -> str: + return self._name + @property def parameters(self) -> set[SfgVar]: return self._parameters @@ -373,13 +391,13 @@ class SfgMethod(SfgClassMember): class SfgConstructor(SfgClassMember): """Constructor of a class""" - __match_args__ = ("parameters", "initializers", "body") + __match_args__ = ("owning_class", "parameters", "initializers", "body") def __init__( self, cls: SfgClass, parameters: Sequence[SfgVar] = (), - initializers: Sequence[str] = (), + initializers: Sequence[tuple[SfgVar | str, tuple[ExprLike, ...]]] = (), body: str = "", ): super().__init__(cls) @@ -392,7 +410,7 @@ class SfgConstructor(SfgClassMember): return self._parameters @property - def initializers(self) -> tuple[str, ...]: + def initializers(self) -> tuple[tuple[SfgVar | str, tuple[ExprLike, ...]], ...]: return self._initializers @property @@ -403,7 +421,7 @@ class SfgConstructor(SfgClassMember): class SfgClass(SfgCodeEntity): """A C++ class.""" - __match_args__ = ("class_name",) + __match_args__ = ("class_keyword", "name") def __init__( self, diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index ca6d9f21b8d91230053b4d06a93698d357c17e5e..5563783a5fa14f637e0f16fd0579f266cc9c0c27 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, Iterable +from typing import Sequence, Iterable import warnings from functools import reduce from dataclasses import dataclass @@ -13,6 +13,7 @@ from pystencils.types import deconstify, PsType from pystencils.codegen.properties import FieldBasePtr, FieldShape, FieldStride from ..exceptions import SfgException +from ..config import CodeStyle from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements from ..lang.expressions import SfgKernelParamVar @@ -27,9 +28,6 @@ from ..lang import ( includes, ) -if TYPE_CHECKING: - from ..context import SfgContext - class FlattenSequences: """Flattens any nested sequences occuring in a kernel call tree.""" @@ -198,7 +196,7 @@ class SfgDeferredNode(SfgCallTreeNode, ABC): def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: pass - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: raise SfgException( "Invalid access into deferred node; deferred nodes must be expanded first." ) diff --git a/src/pystencilssfg/ir/syntax.py b/src/pystencilssfg/ir/syntax.py index 42b705b46686a24dca8c181db6e54e8ee69fda8c..574c02930d224b6f130ad7b7b337abdb41af36c2 100644 --- a/src/pystencilssfg/ir/syntax.py +++ b/src/pystencilssfg/ir/syntax.py @@ -64,12 +64,7 @@ class SfgEntityDef(Generic[SourceEntity_T]): return self._entity -SfgClassBodyElement = ( - str - | SfgEntityDecl[SfgClassMember] - | SfgEntityDef[SfgClassMember] - | SfgMemberVariable -) +SfgClassBodyElement = str | SfgEntityDecl[SfgClassMember] | SfgEntityDef[SfgClassMember] """Elements that may be placed in the visibility blocks of a class body.""" @@ -122,7 +117,10 @@ class SfgNamespaceBlock: parent: Parent namespace enclosing this namespace """ - __match_args__ = ("name", "elements",) + __match_args__ = ( + "namespace", + "elements", + ) def __init__(self, namespace: SfgNamespace) -> None: self._namespace = namespace @@ -145,10 +143,22 @@ class SfgNamespaceBlock: class SfgClassBody: """Body of a class definition.""" - def __init__(self, cls: SfgClass) -> None: + __match_args__ = ("associated_class", "visibility_blocks") + + def __init__( + self, + cls: SfgClass, + default_block: SfgVisibilityBlock, + vis_blocks: Iterable[SfgVisibilityBlock], + ) -> None: self._cls = cls - self._default_block = SfgVisibilityBlock(SfgVisibility.DEFAULT) - self._blocks = [self._default_block] + assert default_block.visibility == SfgVisibility.DEFAULT + self._default_block = default_block + self._blocks = [self._default_block] + list(vis_blocks) + + @property + def associated_class(self) -> SfgClass: + return self._cls @property def default(self) -> SfgVisibilityBlock: @@ -161,11 +171,14 @@ class SfgClassBody: ) self._blocks.append(block) + @property def visibility_blocks(self) -> tuple[SfgVisibilityBlock, ...]: return tuple(self._blocks) -SfgNamespaceElement = str | SfgNamespaceBlock | SfgEntityDecl | SfgEntityDef +SfgNamespaceElement = ( + str | SfgNamespaceBlock | SfgClassBody | SfgEntityDecl | SfgEntityDef +) """Elements that may be placed inside a namespace, including the global namespace.""" diff --git a/src/pystencilssfg/visitors/__init__.py b/src/pystencilssfg/visitors/__init__.py deleted file mode 100644 index fc7af1b6363bdb1167ee6c0e164ba87e31fbb6b0..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/visitors/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .dispatcher import visitor - -__all__ = [ - "visitor", -] diff --git a/src/pystencilssfg/visitors/dispatcher.py b/src/pystencilssfg/visitors/dispatcher.py deleted file mode 100644 index 85a0f087bbf5c1ffcfb99d3ec3acbdbda77089ab..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/visitors/dispatcher.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations -from typing import Callable, TypeVar, Generic -from types import MethodType - -from functools import wraps - -V = TypeVar("V") -R = TypeVar("R") - - -class VisitorDispatcher(Generic[V, R]): - def __init__(self, wrapped_method: Callable[..., R]): - self._dispatch_dict: dict[type, Callable[..., R]] = {} - self._wrapped_method: Callable[..., R] = wrapped_method - - def case(self, node_type: type): - """Decorator for visitor's case handlers.""" - - def decorate(handler: Callable[..., R]): - if node_type in self._dispatch_dict: - raise ValueError(f"Duplicate visitor case {node_type}") - self._dispatch_dict[node_type] = handler - return handler - - return decorate - - def __call__(self, instance: V, node: object, *args, **kwargs) -> R: - for cls in node.__class__.mro(): - if cls in self._dispatch_dict: - return self._dispatch_dict[cls](instance, node, *args, **kwargs) - - return self._wrapped_method(instance, node, *args, **kwargs) - - def __get__(self, obj: V, objtype=None) -> Callable[..., R]: - if obj is None: - return self - return MethodType(self, obj) - - -def visitor(method): - """Decorator to create a visitor using type-based dispatch. - - Use this decorator to convert a method into a visitor, like shown below. - After declaring a method (e.g. `my_method`) a visitor, - its case handlers can be declared using the `my_method.case` decorator, like this: - - ```Python - class DemoVisitor: - @visitor - def visit(self, obj: object): - # fallback case - ... - - @visit.case(str) - def visit_str(self, obj: str): - # code for handling a str - ``` - - When `visit` is later called with some object `x`, the case handler to be executed is - determined according to the method resolution order of `x` (i.e. along its type's inheritance hierarchy). - If no case matches, the fallback code in the original visitor method is executed. - In this example, if `visit` is called with an object of type `str`, the call is dispatched to `visit_str`. - - This visitor dispatch method is primarily designed for traversing abstract syntax tree structures. - The primary visitor method (`visit` in above example) should define the common parent type of all object - types the visitor can handle, with cases declared for all required subtypes. - However, this type relationship is not enforced at runtime. - """ - return wraps(method)(VisitorDispatcher(method))