diff --git a/conftest.py b/conftest.py index 661e722446f25fca1230c55989f00372bf66802c..287ed04129dbeb00b52fb7016de9861d7f952c11 100644 --- a/conftest.py +++ b/conftest.py @@ -2,21 +2,30 @@ import pytest from os import path -@pytest.fixture(autouse=True) -def prepare_doctest_namespace(doctest_namespace): - from pystencilssfg import SfgContext, SfgComposer - from pystencilssfg import lang - - # Place a composer object in the environment for doctests - - sfg = SfgComposer(SfgContext()) - doctest_namespace["sfg"] = sfg - doctest_namespace["lang"] = lang - - DATA_DIR = path.join(path.split(__file__)[0], "tests/data") @pytest.fixture def sample_config_module(): return path.join(DATA_DIR, "project_config.py") + + +@pytest.fixture +def sfg(): + from pystencilssfg import SfgContext, SfgComposer + from pystencilssfg.ir import SfgSourceFile, SfgSourceFileType + + return SfgComposer( + SfgContext( + header_file=SfgSourceFile("", SfgSourceFileType.HEADER), + impl_file=SfgSourceFile("", SfgSourceFileType.TRANSLATION_UNIT), + ) + ) + + +@pytest.fixture(autouse=True) +def prepare_doctest_namespace(doctest_namespace, sfg): + from pystencilssfg import lang + + doctest_namespace["sfg"] = sfg + doctest_namespace["lang"] = lang diff --git a/docs/source/conf.py b/docs/source/conf.py index da6f4d729898cb0215658f99bf1ecea2b018edb5..d6aab17bc3f667bbfc923c80c2d1c35b9e08a3d7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -59,7 +59,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3.8", None), "numpy": ("https://numpy.org/doc/stable/", None), "sympy": ("https://docs.sympy.org/latest/", None), - "pystencils": ("https://da15siwa.pages.i10git.cs.fau.de/dev-docs/pystencils-nbackend/", None), + "pystencils": ("https://pycodegen.pages.i10git.cs.fau.de/docs/pystencils/2.0dev/", None), } # References diff --git a/docs/source/usage/generator_scripts.md b/docs/source/usage/generator_scripts.md index 4a1f6aa7c34ae4667b938d80fc8bd4b595050361..3141feec607c7eff68394288624b67a192d84ad0 100644 --- a/docs/source/usage/generator_scripts.md +++ b/docs/source/usage/generator_scripts.md @@ -64,7 +64,6 @@ Structure and Verbatim Code: SfgBasicComposer.include SfgBasicComposer.namespace SfgBasicComposer.code - SfgBasicComposer.define_once ``` Kernels and Kernel Namespaces: diff --git a/pyproject.toml b/pyproject.toml index 6ac0327d728b8732d86186e213892ae60134a2ab..93987d5f58adf258c5f5f5ad82acef97523d3e76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ parentdir_prefix = "pystencilssfg-" [tool.coverage.run] omit = [ "setup.py", + "noxfile.py", "src/pystencilssfg/_version.py", "integration/*" ] @@ -68,4 +69,5 @@ exclude_also = [ "\\.\\.\\.\n", "if TYPE_CHECKING:", "@(abc\\.)?abstractmethod", + "assert False" ] diff --git a/src/pystencilssfg/__init__.py b/src/pystencilssfg/__init__.py index b2def3b84ba3eab0706b095878a9c165322c1d8b..fea6f8a10e198f86f18be34a6cad381ed3238e19 100644 --- a/src/pystencilssfg/__init__.py +++ b/src/pystencilssfg/__init__.py @@ -1,5 +1,5 @@ -from .config import SfgConfig -from .generator import SourceFileGenerator, GLOBAL_NAMESPACE, OutputMode +from .config import SfgConfig, GLOBAL_NAMESPACE, OutputMode +from .generator import SourceFileGenerator from .composer import SfgComposer from .context import SfgContext from .lang import SfgVar, AugExpr diff --git a/src/pystencilssfg/cli.py b/src/pystencilssfg/cli.py index 3b321c2d0be13b906384701080b8de870647d507..d612fbc886942f68ea4d24044d1ed5d3ee4cf56c 100644 --- a/src/pystencilssfg/cli.py +++ b/src/pystencilssfg/cli.py @@ -1,11 +1,11 @@ import sys import os from os import path +from typing import NoReturn from argparse import ArgumentParser, BooleanOptionalAction -from .config import CommandLineParameters, SfgConfigException, OutputMode -from .emission import OutputSpec +from .config import CommandLineParameters, SfgConfigException def add_newline_arg(parser): @@ -17,7 +17,7 @@ def add_newline_arg(parser): ) -def cli_main(program="sfg-cli"): +def cli_main(program="sfg-cli") -> NoReturn: parser = ArgumentParser( program, description="pystencilssfg command-line utility for build system integration", @@ -65,7 +65,7 @@ def cli_main(program="sfg-cli"): exit(-1) # should never happen -def version(args): +def version(args) -> NoReturn: from . import __version__ print(__version__, end=os.linesep if args.newline else "") @@ -73,37 +73,37 @@ def version(args): exit(0) -def list_files(args): +def list_files(args) -> NoReturn: cli_params = CommandLineParameters(args) config = cli_params.get_config() _, scriptname = path.split(args.codegen_script) basename = path.splitext(scriptname)[0] - output_spec = OutputSpec.create(config, basename) - output_files = [output_spec.get_header_filepath()] - if config.output_mode != OutputMode.HEADER_ONLY: - output_files.append(output_spec.get_impl_filepath()) + output_files = config._get_output_files(basename) - print(args.sep.join(output_files), end=os.linesep if args.newline else "") + print( + args.sep.join(str(of) for of in output_files), + end=os.linesep if args.newline else "", + ) exit(0) -def print_cmake_modulepath(args): +def print_cmake_modulepath(args) -> NoReturn: from .cmake import get_sfg_cmake_modulepath print(get_sfg_cmake_modulepath(), end=os.linesep if args.newline else "") exit(0) -def make_cmake_find_module(args): +def make_cmake_find_module(args) -> NoReturn: from .cmake import make_find_module make_find_module() exit(0) -def abort_with_config_exception(exception: SfgConfigException, source: str): +def abort_with_config_exception(exception: SfgConfigException, source: str) -> NoReturn: print(f"Invalid {source} configuration: {exception.args[0]}.", file=sys.stderr) exit(1) diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index b96d559a1733ec69b6f40db04f41437cc80e3ad7..e75f0e29e4d41825a8d0d8cf1547bab046c1731a 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -1,11 +1,11 @@ from __future__ import annotations from typing import Sequence, TypeAlias from abc import ABC, abstractmethod -import numpy as np import sympy as sp from functools import reduce +from warnings import warn -from pystencils import Field +from pystencils import Field, CreateKernelConfig, create_kernel from pystencils.codegen import Kernel from pystencils.types import create_type, UserTypeSpec @@ -28,15 +28,13 @@ from ..ir.postprocessing import ( SfgDeferredFieldMapping, SfgDeferredVectorMapping, ) -from ..ir.source_components import ( +from ..ir import ( SfgFunction, - SfgHeaderInclude, SfgKernelNamespace, SfgKernelHandle, - SfgClass, - SfgConstructor, - SfgMemberVariable, - SfgClassKeyword, + SfgEntityDecl, + SfgEntityDef, + SfgNamespaceBlock, ) from ..lang import ( VarLike, @@ -60,6 +58,7 @@ from ..exceptions import SfgException class SfgIComposer(ABC): def __init__(self, ctx: SfgContext): self._ctx = ctx + self._cursor = ctx.cursor @property def context(self): @@ -79,6 +78,68 @@ SequencerArg: TypeAlias = tuple | ExprLike | SfgCallTreeNode | SfgNodeBuilder """Valid arguments to `make_sequence` and any sequencer that uses it.""" +class KernelsAdder: + def __init__(self, ctx: SfgContext, loc: SfgNamespaceBlock): + self._ctx = ctx + self._loc = loc + assert isinstance(loc.namespace, SfgKernelNamespace) + self._kernel_namespace = loc.namespace + + def add(self, kernel: Kernel, name: str | None = None): + """Adds an existing pystencils AST to this namespace. + If a name is specified, the AST's function name is changed.""" + if name is None: + kernel_name = kernel.name + else: + kernel_name = name + + if self._kernel_namespace.find_kernel(kernel_name) is not None: + raise ValueError( + f"Duplicate kernels: A kernel called {kernel_name} already exists " + f"in namespace {self._kernel_namespace.fqname}" + ) + + if name is not None: + kernel.name = kernel_name + + khandle = SfgKernelHandle(kernel_name, self._kernel_namespace, kernel) + self._kernel_namespace.add_kernel(khandle) + + self._loc.elements.append(SfgEntityDef(khandle)) + + for header in kernel.required_headers: + assert self._ctx.impl_file is not None + self._ctx.impl_file.includes.append(HeaderFile.parse(header)) + + return khandle + + def create( + self, + assignments, + name: str | None = None, + config: CreateKernelConfig | None = None, + ): + """Creates a new pystencils kernel from a list of assignments and a configuration. + This is a wrapper around `pystencils.create_kernel` + with a subsequent call to `add`. + """ + if config is None: + config = CreateKernelConfig() + + if name is not None: + if self._kernel_namespace.find_kernel(name) is not None: + raise ValueError( + f"Duplicate kernels: A kernel called {name} already exists " + f"in namespace {self._kernel_namespace.fqname}" + ) + + config.function_name = name + + # type: ignore + kernel = create_kernel(assignments, config=config) + return self.add(kernel) + + class SfgBasicComposer(SfgIComposer): """Composer for basic source components, and base class for all composer mix-ins.""" @@ -86,7 +147,7 @@ class SfgBasicComposer(SfgIComposer): ctx: SfgContext = sfg if isinstance(sfg, SfgContext) else sfg.context super().__init__(ctx) - def prelude(self, content: str): + def prelude(self, content: str, end: str = "\n"): """Append a string to the prelude comment, to be printed at the top of both generated files. The string should not contain C/C++ comment delimiters, since these will be added automatically @@ -104,7 +165,11 @@ class SfgBasicComposer(SfgIComposer): */ """ - self._ctx.append_to_prelude(content) + for f in self._ctx.files: + if f.prelude is None: + f.prelude = content + end + else: + f.prelude += content + end def code(self, *code: str): """Add arbitrary lines of code to the generated header file. @@ -125,7 +190,7 @@ class SfgBasicComposer(SfgIComposer): """ for c in code: - self._ctx.add_definition(c) + self._cursor.write_header(c) def define(self, *definitions: str): from warnings import warn @@ -138,34 +203,41 @@ class SfgBasicComposer(SfgIComposer): self.code(*definitions) - def define_once(self, *definitions: str): - """Add unique definitions to the header file. + def namespace(self, namespace: str): + """Enter a new namespace block. - Each code string given to `define_once` will only be added if the exact same string - was not already added before. - """ - for definition in definitions: - if all(d != definition for d in self._ctx.definitions()): - self._ctx.add_definition(definition) + Calling `namespace` as a regular function will open a new namespace as a child of the + currently active namespace; this new namespace will then become active instead. + Using `namespace` as a context manager will instead activate the given namespace + only for the length of the ``with`` block. - def namespace(self, namespace: str): - """Set the inner code namespace. Throws an exception if a namespace was already set. + Args: + namespace: Qualified name of the namespace :Example: - After adding the following to your generator script: + The following calls will set the current namespace to ``outer::inner`` + for the remaining code generation run: - >>> sfg.namespace("codegen_is_awesome") + .. code-block:: - All generated code will be placed within that namespace: + sfg.namespace("outer") + sfg.namespace("inner") - .. code-block:: C++ + Subsequent calls to `namespace` can only create further nested namespaces. + + To step back out of a namespace, `namespace` can also be used as a context manager: + + .. code-block:: + + with sfg.namespace("detail"): + ... + + This way, code generated inside the ``with`` region is placed in the ``detail`` namespace, + and code after this block will again live in the enclosing namespace. - namespace codegen_is_awesome { - /* all generated code */ - } """ - self._ctx.set_namespace(namespace) + return self._cursor.enter_namespace(namespace) def generate(self, generator: CustomGenerator): """Invoke a custom code generator with the underlying context.""" @@ -174,7 +246,7 @@ class SfgBasicComposer(SfgIComposer): generator.generate(SfgComposer(self)) @property - def kernels(self) -> SfgKernelNamespace: + def kernels(self) -> KernelsAdder: """The default kernel namespace. Add kernels like:: @@ -182,18 +254,24 @@ class SfgBasicComposer(SfgIComposer): sfg.kernels.add(ast, "kernel_name") sfg.kernels.create(assignments, "kernel_name", config) """ - return self._ctx._default_kernel_namespace + return self.kernel_namespace("kernels") - def kernel_namespace(self, name: str) -> SfgKernelNamespace: + def kernel_namespace(self, name: str) -> KernelsAdder: """Return the kernel namespace of the given name, creating it if it does not exist yet.""" - kns = self._ctx.get_kernel_namespace(name) + kns = self._cursor.get_entity("kernels") if kns is None: - kns = SfgKernelNamespace(self._ctx, name) - self._ctx.add_kernel_namespace(kns) + kns = SfgKernelNamespace("kernels", self._cursor.current_namespace) + self._cursor.add_entity(kns) + elif not isinstance(kns, SfgKernelNamespace): + raise ValueError( + f"The existing entity {kns.fqname} is not a kernel namespace" + ) - return kns + kns_block = SfgNamespaceBlock(kns) + self._cursor.write_impl(kns_block) + return KernelsAdder(self._ctx, kns_block) - def include(self, header_file: str, private: bool = False): + def include(self, header: str | HeaderFile, private: bool = False): """Include a header file. Args: @@ -213,46 +291,37 @@ class SfgBasicComposer(SfgIComposer): #include <vector> #include "custom.h" """ - self._ctx.add_include(SfgHeaderInclude(HeaderFile.parse(header_file), private)) - - def numpy_struct( - self, name: str, dtype: np.dtype, add_constructor: bool = True - ) -> SfgClass: - """Add a numpy structured data type as a C++ struct - - Returns: - The created class object - """ - if self._ctx.get_class(name) is not None: - raise SfgException(f"Class with name {name} already exists.") - - cls = _struct_from_numpy_dtype(name, dtype, add_constructor=add_constructor) - self._ctx.add_class(cls) - return cls + header_file = HeaderFile.parse(header) + + if private: + if self._ctx.impl_file is None: + raise ValueError( + "Cannot emit a private include since no implementation file is being generated" + ) + self._ctx.impl_file.includes.append(header_file) + else: + self._ctx.header_file.includes.append(header_file) - def kernel_function( - self, name: str, ast_or_kernel_handle: Kernel | SfgKernelHandle - ): + def kernel_function(self, name: str, kernel: Kernel | SfgKernelHandle): """Create a function comprising just a single kernel call. Args: ast_or_kernel_handle: Either a pystencils AST, or a kernel handle for an already registered AST. """ - if self._ctx.get_function(name) is not None: - raise ValueError(f"Function {name} already exists.") - - if isinstance(ast_or_kernel_handle, Kernel): - khandle = self._ctx.default_kernel_namespace.add(ast_or_kernel_handle) - tree = SfgKernelCallNode(khandle) - elif isinstance(ast_or_kernel_handle, SfgKernelHandle): - tree = SfgKernelCallNode(ast_or_kernel_handle) + if isinstance(kernel, Kernel): + khandle = self.kernels.add(kernel, name) else: - raise TypeError("Invalid type of argument `ast_or_kernel_handle`!") + khandle = kernel - func = SfgFunction(name, tree) - self._ctx.add_function(func) + self.function(name)(self.call(khandle)) - def function(self, name: str, return_type: UserTypeSpec = void): + def function( + self, + name: str, + returns: UserTypeSpec = void, + inline: bool = False, + return_type: UserTypeSpec | None = None, + ): """Add a function. The syntax of this function adder uses a chain of two calls to mimic C++ syntax: @@ -265,13 +334,31 @@ class SfgBasicComposer(SfgIComposer): The function body is constructed via sequencing (see `make_sequence`). """ - if self._ctx.get_function(name) is not None: - raise ValueError(f"Function {name} already exists.") + if return_type is not None: + warn( + "The parameter `return_type` to `function()` is deprecated and will be removed by version 0.1. " + "Setting it will override the value of the `returns` parameter. " + "Use `returns` instead.", + FutureWarning, + ) + returns = return_type def sequencer(*args: SequencerArg): tree = make_sequence(*args) - func = SfgFunction(name, tree, return_type=create_type(return_type)) - self._ctx.add_function(func) + func = SfgFunction( + name, + self._cursor.current_namespace, + tree, + return_type=create_type(returns), + inline=inline, + ) + self._cursor.add_entity(func) + + if inline: + self._cursor.write_header(SfgEntityDef(func)) + else: + self._cursor.write_header(SfgEntityDecl(func)) + self._cursor.write_impl(SfgEntityDef(func)) return sequencer @@ -610,33 +697,3 @@ class SfgSwitchBuilder(SfgNodeBuilder): def resolve(self) -> SfgCallTreeNode: return SfgSwitch(make_statements(self._switch_arg), self._cases, self._default) - - -def _struct_from_numpy_dtype( - struct_name: str, dtype: np.dtype, add_constructor: bool = True -): - cls = SfgClass(struct_name, class_keyword=SfgClassKeyword.STRUCT) - - fields = dtype.fields - if fields is None: - raise SfgException(f"Numpy dtype {dtype} is not a structured type.") - - constr_params = [] - constr_inits = [] - - for member_name, type_info in fields.items(): - member_type = create_type(type_info[0]) - - member = SfgMemberVariable(member_name, member_type) - - arg = SfgVar(f"{member_name}_", member_type) - - cls.default.append_member(member) - - constr_params.append(arg) - constr_inits.append(f"{member}({arg})") - - if add_constructor: - cls.default.append_member(SfgConstructor(constr_params, constr_inits)) - - return cls diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 489823b9ce619be88e3220ce4b941cf49c62b298..0a72e8089ecd5e32be53cd335df57b58b21ec578 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -1,26 +1,30 @@ from __future__ import annotations from typing import Sequence +from itertools import takewhile, dropwhile +import numpy as np from pystencils.types import PsCustomType, UserTypeSpec, create_type +from ..context import SfgContext from ..lang import ( - _VarLike, VarLike, ExprLike, asvar, SfgVar, ) -from ..ir.source_components import ( +from ..ir import ( + SfgCallTreeNode, SfgClass, - SfgClassMember, - SfgInClassDefinition, SfgConstructor, SfgMethod, SfgMemberVariable, SfgClassKeyword, SfgVisibility, SfgVisibilityBlock, + SfgEntityDecl, + SfgEntityDef, + SfgClassBody, ) from ..exceptions import SfgException @@ -40,31 +44,88 @@ class SfgClassComposer(SfgComposerMixIn): Its interface is exposed by :class:`SfgComposer`. """ - class VisibilityContext: + class VisibilityBlockSequencer: """Represent a visibility block in the composer syntax. Returned by `private`, `public`, and `protected`. """ def __init__(self, visibility: SfgVisibility): - self._vis_block = SfgVisibilityBlock(visibility) - - def members(self): - yield from self._vis_block.members() + self._visibility = visibility + self._args: tuple[ + SfgClassComposer.MethodSequencer + | SfgClassComposer.ConstructorBuilder + | VarLike + | str, + ..., + ] def __call__( self, *args: ( - SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str + SfgClassComposer.MethodSequencer + | SfgClassComposer.ConstructorBuilder + | VarLike + | str ), ): - for arg in args: - self._vis_block.append_member(SfgClassComposer._resolve_member(arg)) + self._args = args + return self + def _resolve(self, ctx: SfgContext, cls: SfgClass) -> SfgVisibilityBlock: + vis_block = SfgVisibilityBlock(self._visibility) + for arg in self._args: + match arg: + case ( + SfgClassComposer.MethodSequencer() + | SfgClassComposer.ConstructorBuilder() + ): + arg._resolve(ctx, cls, vis_block) + case str(): + vis_block.elements.append(arg) + case _: + var = asvar(arg) + member_var = SfgMemberVariable(var.name, var.dtype, cls) + cls.add_member(member_var, vis_block.visibility) + vis_block.elements.append(SfgEntityDef(member_var)) + return vis_block + + class MethodSequencer: + def __init__( + self, + name: str, + returns: UserTypeSpec = PsCustomType("void"), + inline: bool = False, + const: bool = False, + ) -> None: + self._name = name + self._returns = create_type(returns) + self._inline = inline + self._const = const + self._tree: SfgCallTreeNode + + def __call__(self, *args: SequencerArg): + self._tree = make_sequence(*args) return self - def resolve(self, cls: SfgClass) -> None: - cls.append_visibility_block(self._vis_block) + def _resolve( + self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock + ): + method = SfgMethod( + self._name, + cls, + self._tree, + return_type=self._returns, + inline=self._inline, + const=self._const, + ) + cls.add_member(method, vis_block.visibility) + + if self._inline: + vis_block.elements.append(SfgEntityDef(method)) + else: + vis_block.elements.append(SfgEntityDecl(method)) + ctx._cursor.write_impl(SfgEntityDef(method)) class ConstructorBuilder: """Composer syntax for constructor building. @@ -74,7 +135,7 @@ class SfgClassComposer(SfgComposerMixIn): def __init__(self, *params: VarLike): self._params = list(asvar(p) for p in params) - self._initializers: list[str] = [] + self._initializers: list[tuple[SfgVar | str, tuple[ExprLike, ...]]] = [] self._body: str | None = None def add_param(self, param: VarLike, at: int | None = None): @@ -93,9 +154,7 @@ class SfgClassComposer(SfgComposerMixIn): member = var if isinstance(var, str) else asvar(var) def init_sequencer(*args: ExprLike): - expr = ", ".join(str(arg) for arg in args) - initializer = f"{member}{{ {expr} }}" - self._initializers.append(initializer) + self._initializers.append((member, args)) return self return init_sequencer @@ -107,13 +166,19 @@ class SfgClassComposer(SfgComposerMixIn): self._body = body return self - def resolve(self) -> SfgConstructor: - return SfgConstructor( + def _resolve( + self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock + ): + ctor = SfgConstructor( + cls, parameters=self._params, initializers=self._initializers, body=self._body if self._body is not None else "", ) + cls.add_member(ctor, vis_block.visibility) + vis_block.elements.append(SfgEntityDef(ctor)) + def klass(self, class_name: str, bases: Sequence[str] = ()): """Create a class and add it to the underlying context. @@ -132,20 +197,30 @@ class SfgClassComposer(SfgComposerMixIn): """ return self._class(class_name, SfgClassKeyword.STRUCT, bases) + def numpy_struct( + self, name: str, dtype: np.dtype, add_constructor: bool = True + ): + """Add a numpy structured data type as a C++ struct + + Returns: + The created class object + """ + return self._struct_from_numpy_dtype(name, dtype, add_constructor) + @property - def public(self) -> SfgClassComposer.VisibilityContext: + def public(self) -> SfgClassComposer.VisibilityBlockSequencer: """Create a `public` visibility block in a class body""" - return SfgClassComposer.VisibilityContext(SfgVisibility.PUBLIC) + return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PUBLIC) @property - def protected(self) -> SfgClassComposer.VisibilityContext: + def protected(self) -> SfgClassComposer.VisibilityBlockSequencer: """Create a `protected` visibility block in a class or struct body""" - return SfgClassComposer.VisibilityContext(SfgVisibility.PROTECTED) + return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PROTECTED) @property - def private(self) -> SfgClassComposer.VisibilityContext: + def private(self) -> SfgClassComposer.VisibilityBlockSequencer: """Create a `private` visibility block in a class or struct body""" - return SfgClassComposer.VisibilityContext(SfgVisibility.PRIVATE) + return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PRIVATE) def constructor(self, *params: VarLike): """In a class or struct body or visibility block, add a constructor. @@ -172,76 +247,85 @@ class SfgClassComposer(SfgComposerMixIn): const: Whether or not the method is const-qualified. """ - def sequencer(*args: SequencerArg): - tree = make_sequence(*args) - return SfgMethod( - name, - tree, - return_type=create_type(returns), - inline=inline, - const=const, - ) - - return sequencer + return SfgClassComposer.MethodSequencer(name, returns, inline, const) # INTERNALS def _class(self, class_name: str, keyword: SfgClassKeyword, bases: Sequence[str]): - if self._ctx.get_class(class_name) is not None: - raise ValueError(f"Class or struct {class_name} already exists.") + # TODO: Return a `CppClass` instance representing the generated class - cls = SfgClass(class_name, class_keyword=keyword, bases=bases) - self._ctx.add_class(cls) + if self._cursor.get_entity(class_name) is not None: + raise ValueError( + f"Another entity with name {class_name} already exists in the current namespace." + ) + + cls = SfgClass( + class_name, + self._cursor.current_namespace, + class_keyword=keyword, + bases=bases, + ) + self._cursor.add_entity(cls) def sequencer( *args: ( - SfgClassComposer.VisibilityContext - | SfgClassMember + SfgClassComposer.VisibilityBlockSequencer + | SfgClassComposer.MethodSequencer | SfgClassComposer.ConstructorBuilder | VarLike | str ), ): - default_ended = False - - for arg in args: - if isinstance(arg, SfgClassComposer.VisibilityContext): - default_ended = True - arg.resolve(cls) - elif isinstance( - arg, - ( - SfgClassMember, - SfgClassComposer.ConstructorBuilder, - str, - ) - + _VarLike, - ): - if default_ended: - raise SfgException( - "Composer Syntax Error: " - "Cannot add members with default visibility after a visibility block." - ) - else: - cls.default.append_member(self._resolve_member(arg)) + default_vis_sequencer = SfgClassComposer.VisibilityBlockSequencer( + SfgVisibility.DEFAULT + ) + + def argfilter(arg): + return not isinstance(arg, SfgClassComposer.VisibilityBlockSequencer) + + default_vis_args = takewhile( + argfilter, + args, + ) + default_block = default_vis_sequencer(*default_vis_args)._resolve(self._ctx, cls) # type: ignore + vis_blocks: list[SfgVisibilityBlock] = [] + + for arg in dropwhile(argfilter, args): + if isinstance(arg, SfgClassComposer.VisibilityBlockSequencer): + vis_blocks.append(arg._resolve(self._ctx, cls)) else: - raise SfgException(f"{arg} is not a valid class member.") + raise SfgException( + "Composer Syntax Error: " + "Cannot add members with default visibility after a visibility block." + ) + + self._cursor.write_header(SfgClassBody(cls, default_block, vis_blocks)) return sequencer - @staticmethod - def _resolve_member( - arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str, - ) -> SfgClassMember: - match arg: - case _ if isinstance(arg, _VarLike): - var = asvar(arg) - return SfgMemberVariable(var.name, var.dtype) - case str(): - return SfgInClassDefinition(arg) - case SfgClassComposer.ConstructorBuilder(): - return arg.resolve() - case SfgClassMember(): - return arg - case _: - raise ValueError(f"Invalid class member: {arg}") + def _struct_from_numpy_dtype( + self, struct_name: str, dtype: np.dtype, add_constructor: bool = True + ): + fields = dtype.fields + if fields is None: + raise SfgException(f"Numpy dtype {dtype} is not a structured type.") + + members: list[SfgClassComposer.ConstructorBuilder | SfgVar] = [] + if add_constructor: + ctor = self.constructor() + members.append(ctor) + + for member_name, type_info in fields.items(): + member_type = create_type(type_info[0]) + + member = SfgVar(member_name, member_type) + members.append(member) + + if add_constructor: + arg = SfgVar(f"{member_name}_", member_type) + ctor.add_param(arg) + ctor.init(member)(arg) + + return self.struct( + struct_name, + )(*members) diff --git a/src/pystencilssfg/composer/mixin.py b/src/pystencilssfg/composer/mixin.py index 3ee8efa61227184686001045bf3f8cb23525bf02..34b1c5856a3b923f4b15939110bf3bfbc0ba5970 100644 --- a/src/pystencilssfg/composer/mixin.py +++ b/src/pystencilssfg/composer/mixin.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ..context import SfgContext +from ..context import SfgContext, SfgCursor from .basic_composer import SfgBasicComposer @@ -14,6 +14,7 @@ class SfgComposerMixIn: def __init__(self) -> None: self._ctx: SfgContext + self._cursor: SfgCursor @property def _composer(self) -> SfgBasicComposer: diff --git a/src/pystencilssfg/config.py b/src/pystencilssfg/config.py index aae9dab541f95c2d7af09da46bce1346150bb4a3..3b63f4f121d27b53a3973a93996a6d0a99fe09e6 100644 --- a/src/pystencilssfg/config.py +++ b/src/pystencilssfg/config.py @@ -3,14 +3,16 @@ from __future__ import annotations from argparse import ArgumentParser from types import ModuleType -from typing import Any, Sequence +from typing import Any, Sequence, Callable from dataclasses import dataclass from enum import Enum, auto from os import path from importlib import util as iutil +from pathlib import Path +from pystencils.codegen.config import ConfigBase, Option, BasicOption, Category -from pystencils.codegen.config import ConfigBase, BasicOption, Category +from .lang import HeaderFile class SfgConfigException(Exception): ... # noqa: E701 @@ -61,6 +63,12 @@ class CodeStyle(ConfigBase): indent_width: BasicOption[int] = BasicOption(2) """The number of spaces successively nested blocks should be indented with""" + includes_sorting_key: BasicOption[Callable[[HeaderFile], Any]] = BasicOption() + """Key function that will be used to sort `#include` statements in generated files. + + Pystencils-sfg will instruct clang-tidy to forego include sorting if this option is set. + """ + # TODO possible future options: # - newline before opening { # - trailing return types @@ -166,9 +174,34 @@ class SfgConfig(ConfigBase): ClangFormatOptions.binary """ - output_directory: BasicOption[str] = BasicOption(".") + output_directory: Option[Path, str | Path] = Option(Path(".")) """Directory to which the generated files should be written.""" + @output_directory.validate + def _validate_output_directory(self, pth: str | Path) -> Path: + return Path(pth) + + def _get_output_files(self, basename: str): + output_dir: Path = self.get_option("output_directory") + + header_ext = self.extensions.get_option("header") + impl_ext = self.extensions.get_option("impl") + output_files = [output_dir / f"{basename}.{header_ext}"] + output_mode = self.get_option("output_mode") + + if impl_ext is None: + match output_mode: + case OutputMode.INLINE: + impl_ext = "ipp" + case OutputMode.STANDALONE: + impl_ext = "cpp" + + if output_mode != OutputMode.HEADER_ONLY: + assert impl_ext is not None + output_files.append(output_dir / f"{basename}.{impl_ext}") + + return tuple(output_files) + class CommandLineParameters: @staticmethod diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index 17537a26706be5e5a9c93bc9cc5c09bb13c48dff..199c678ba28e449f42b29c56147b9a4fd0d523bb 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -1,83 +1,51 @@ -from typing import Generator, Sequence, Any +from __future__ import annotations +from typing import Sequence, Any, Generator +from contextlib import contextmanager from .config import CodeStyle -from .ir.source_components import ( - SfgHeaderInclude, - SfgKernelNamespace, - SfgFunction, - SfgClass, +from .ir import ( + SfgSourceFile, + SfgNamespace, + SfgNamespaceBlock, + SfgCodeEntity, + SfgGlobalNamespace, ) +from .ir.syntax import SfgNamespaceElement from .exceptions import SfgException class SfgContext: - """Represents a header/implementation file pair in the code generator. - - **Source File Properties and Components** - - The SfgContext collects all properties and components of a header/implementation - file pair (or just the header file, if header-only generation is used). - These are: - - - The code namespace, which is combined from the `outer_namespace` - and the `pystencilssfg.SfgContext.inner_namespace`. The outer namespace is meant to be set - externally e.g. by the project configuration, while the inner namespace is meant to be set by the generator - script. - - The `prelude comment` is a block of text printed as a comment block - at the top of both generated files. Typically, it contains authorship and licence information. - - The set of included header files (`pystencilssfg.SfgContext.includes`). - - Custom `definitions`, which are just arbitrary code strings. - - Any number of kernel namespaces (`pystencilssfg.SfgContext.kernel_namespaces`), within which *pystencils* - kernels are managed. - - Any number of functions (`pystencilssfg.SfgContext.functions`), which are meant to serve as wrappers - around kernel calls. - - Any number of classes (`pystencilssfg.SfgContext.classes`), which can be used to build more extensive wrappers - around kernels. - - **Order of Definitions** - - To honor C/C++ use-after-declare rules, the context preserves the order in which definitions, functions and classes - are added to it. - The header file printers implemented in *pystencils-sfg* will print the declarations accordingly. - The declarations can retrieved in order of definition via `declarations_ordered`. - """ + """Manages context information during the execution of a generator script.""" def __init__( self, - outer_namespace: str | None = None, + header_file: SfgSourceFile, + impl_file: SfgSourceFile | None, + namespace: str | None = None, codestyle: CodeStyle | None = None, argv: Sequence[str] | None = None, project_info: Any = None, ): - """ - Args: - outer_namespace: Qualified name of the outer code namespace - codestyle: Code style that should be used by the code emitter - argv: The generator script's command line arguments. - Reserved for internal use by the [SourceFileGenerator][pystencilssfg.SourceFileGenerator]. - project_info: Project-specific information provided by a build system. - Reserved for internal use by the [SourceFileGenerator][pystencilssfg.SourceFileGenerator]. - """ self._argv = argv self._project_info = project_info - self._default_kernel_namespace = SfgKernelNamespace(self, "kernels") - self._outer_namespace = outer_namespace + self._outer_namespace = namespace self._inner_namespace: str | None = None self._codestyle = codestyle if codestyle is not None else CodeStyle() - # Source Components - self._prelude: str = "" - self._includes: list[SfgHeaderInclude] = [] - self._definitions: list[str] = [] - self._kernel_namespaces = { - self._default_kernel_namespace.name: self._default_kernel_namespace - } - self._functions: dict[str, SfgFunction] = dict() - self._classes: dict[str, SfgClass] = dict() + self._header_file = header_file + self._impl_file = impl_file - self._declarations_ordered: list[str | SfgFunction | SfgClass] = list() + self._global_namespace = SfgGlobalNamespace() + + current_namespace: SfgNamespace + if namespace is not None: + current_namespace = self._global_namespace.get_child_namespace(namespace) + else: + current_namespace = self._global_namespace + + self._cursor = SfgCursor(self, current_namespace) @property def argv(self) -> Sequence[str]: @@ -100,163 +68,90 @@ class SfgContext: """Outer code namespace. Set by constructor argument `outer_namespace`.""" return self._outer_namespace - @property - def inner_namespace(self) -> str | None: - """Inner code namespace. Set by `set_namespace`.""" - return self._inner_namespace - - @property - def fully_qualified_namespace(self) -> str | None: - """Combined outer and inner namespaces, as `outer_namespace::inner_namespace`.""" - match (self.outer_namespace, self.inner_namespace): - case None, None: - return None - case outer, None: - return outer - case None, inner: - return inner - case outer, inner: - return f"{outer}::{inner}" - case _: - assert False - @property def codestyle(self) -> CodeStyle: """The code style object for this generation context.""" return self._codestyle - # ---------------------------------------------------------------------------------------------- - # Prelude, Includes, Definitions, Namespace - # ---------------------------------------------------------------------------------------------- - @property - def prelude_comment(self) -> str: - """The prelude is a comment block printed at the top of both generated files.""" - return self._prelude - - def append_to_prelude(self, code_str: str): - """Append a string to the prelude comment. - - The string should not contain - C/C++ comment delimiters, since these will be added automatically during - code generation. - """ - if self._prelude: - self._prelude += "\n" - - self._prelude += code_str - - if not code_str.endswith("\n"): - self._prelude += "\n" - - def includes(self) -> Generator[SfgHeaderInclude, None, None]: - """Includes of headers. Public includes are added to the header file, private includes - are added to the implementation file.""" - yield from self._includes - - def add_include(self, include: SfgHeaderInclude): - self._includes.append(include) - - def definitions(self) -> Generator[str, None, None]: - """Definitions are arbitrary custom lines of code.""" - yield from self._definitions - - def add_definition(self, definition: str): - """Add a custom code string to the header file.""" - self._definitions.append(definition) - self._declarations_ordered.append(definition) - - def set_namespace(self, namespace: str): - """Set the inner code namespace. - - Throws an exception if the namespace was already set. - """ - if self._inner_namespace is not None: - raise SfgException("The code namespace was already set.") - - self._inner_namespace = namespace - - # ---------------------------------------------------------------------------------------------- - # Kernel Namespaces - # ---------------------------------------------------------------------------------------------- + def header_file(self) -> SfgSourceFile: + return self._header_file @property - def default_kernel_namespace(self) -> SfgKernelNamespace: - """The default kernel namespace.""" - return self._default_kernel_namespace - - def kernel_namespaces(self) -> Generator[SfgKernelNamespace, None, None]: - """Iterator over all registered kernel namespaces.""" - yield from self._kernel_namespaces.values() + def impl_file(self) -> SfgSourceFile | None: + return self._impl_file - def get_kernel_namespace(self, str) -> SfgKernelNamespace | None: - """Retrieve a kernel namespace by name, or `None` if it does not exist.""" - return self._kernel_namespaces.get(str) - - def add_kernel_namespace(self, namespace: SfgKernelNamespace): - """Adds a new kernel namespace. + @property + def cursor(self) -> SfgCursor: + return self._cursor - If a kernel namespace of the same name already exists, throws an exception. - """ - if namespace.name in self._kernel_namespaces: - raise ValueError(f"Duplicate kernel namespace: {namespace.name}") + @property + def files(self) -> Generator[SfgSourceFile, None, None]: + yield self._header_file + if self._impl_file is not None: + yield self._impl_file - self._kernel_namespaces[namespace.name] = namespace + @property + def global_namespace(self) -> SfgNamespace: + return self._global_namespace - # ---------------------------------------------------------------------------------------------- - # Functions - # ---------------------------------------------------------------------------------------------- - def functions(self) -> Generator[SfgFunction, None, None]: - """Iterator over all registered functions.""" - yield from self._functions.values() +class SfgCursor: + """Cursor that tracks the current location in the source file(s) during execution of the generator script.""" - def get_function(self, name: str) -> SfgFunction | None: - """Retrieve a function by name. Returns `None` if no function of the given name exists.""" - return self._functions.get(name, None) + def __init__(self, ctx: SfgContext, namespace: SfgNamespace) -> None: + self._ctx = ctx - def add_function(self, func: SfgFunction): - """Adds a new function. + self._cur_namespace: SfgNamespace = namespace - If a function or class with the same name exists already, throws an exception. - """ - if func.name in self._functions or func.name in self._classes: - raise SfgException(f"Duplicate function: {func.name}") + self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] = dict() + for f in self._ctx.files: + if not isinstance(namespace, SfgGlobalNamespace): + block = SfgNamespaceBlock( + self._cur_namespace, self._cur_namespace.fqname + ) + f.elements.append(block) + self._loc[f] = block.elements + else: + self._loc[f] = f.elements - self._functions[func.name] = func - self._declarations_ordered.append(func) + @property + def current_namespace(self) -> SfgNamespace: + return self._cur_namespace - # ---------------------------------------------------------------------------------------------- - # Classes - # ---------------------------------------------------------------------------------------------- + def get_entity(self, name: str) -> SfgCodeEntity | None: + return self._cur_namespace.get_entity(name) - def classes(self) -> Generator[SfgClass, None, None]: - """Iterator over all registered classes.""" - yield from self._classes.values() + def add_entity(self, entity: SfgCodeEntity): + self._cur_namespace.add_entity(entity) - def get_class(self, name: str) -> SfgClass | None: - """Retrieve a class by name, or `None` if the class does not exist.""" - return self._classes.get(name, None) + def write_header(self, elem: SfgNamespaceElement) -> None: + self._loc[self._ctx.header_file].append(elem) - def add_class(self, cls: SfgClass): - """Add a class. + def write_impl(self, elem: SfgNamespaceElement) -> None: + impl_file = self._ctx.impl_file + if impl_file is None: + raise SfgException( + f"Cannot write element {elem} to implemenation file since no implementation file is being generated." + ) + self._loc[impl_file].append(elem) - Throws an exception if a class or function of the same name exists already. - """ - if cls.class_name in self._classes or cls.class_name in self._functions: - raise SfgException(f"Duplicate class: {cls.class_name}") + def enter_namespace(self, qual_name: str): + namespace = self._cur_namespace.get_child_namespace(qual_name) - self._classes[cls.class_name] = cls - self._declarations_ordered.append(cls) + outer_locs = self._loc.copy() - # ---------------------------------------------------------------------------------------------- - # Declarations in order of addition - # ---------------------------------------------------------------------------------------------- + for f in self._ctx.files: + block = SfgNamespaceBlock(namespace, qual_name) + self._loc[f].append(block) + self._loc[f] = block.elements - def declarations_ordered( - self, - ) -> Generator[str | SfgFunction | SfgClass, None, None]: - """All declared definitions, classes and functions in the order they were added. + @contextmanager + def ctxmgr(): + try: + yield None + finally: + # Have the cursor step back out of the nested namespace blocks + self._loc = outer_locs - Awareness about order is necessary due to the C++ declare-before-use rules.""" - yield from self._declarations_ordered + return ctxmgr() diff --git a/src/pystencilssfg/emission/__init__.py b/src/pystencilssfg/emission/__init__.py index fd666283edd12d49c02eb42eee90a5ae8f9756dd..1a22aa26a0bf79acec636488ec6c5e9c00834b78 100644 --- a/src/pystencilssfg/emission/__init__.py +++ b/src/pystencilssfg/emission/__init__.py @@ -1,5 +1,4 @@ -from .emitter import AbstractEmitter, OutputSpec -from .header_impl_pair import HeaderImplPairEmitter -from .header_only import HeaderOnlyEmitter +from .emitter import SfgCodeEmitter +from .file_printer import SfgFilePrinter -__all__ = ["AbstractEmitter", "OutputSpec", "HeaderImplPairEmitter", "HeaderOnlyEmitter"] +__all__ = ["SfgCodeEmitter", "SfgFilePrinter"] diff --git a/src/pystencilssfg/emission/clang_format.py b/src/pystencilssfg/emission/clang_format.py index 1b15e8c4145db67656b29acdd456074ad2d0a895..50c51f176cbbc8abf3872fab319b3069af58141e 100644 --- a/src/pystencilssfg/emission/clang_format.py +++ b/src/pystencilssfg/emission/clang_format.py @@ -5,14 +5,16 @@ from ..config import ClangFormatOptions from ..exceptions import SfgException -def invoke_clang_format(code: str, options: ClangFormatOptions) -> str: +def invoke_clang_format( + code: str, options: ClangFormatOptions, sort_includes: str | None = None +) -> str: """Call the `clang-format` command-line tool to format the given code string according to the given style arguments. Args: code: Code string to format - codestyle: [SfgCodeStyle][pystencilssfg.configuration.SfgCodeStyle] object - defining the `clang-format` binary and the desired code style. + options: Options controlling the clang-format invocation + sort_includes: Option to be passed on to clang-format's ``--sort-includes`` argument Returns: The formatted code, if `clang-format` was run sucessfully. @@ -32,6 +34,9 @@ def invoke_clang_format(code: str, options: ClangFormatOptions) -> str: style = options.get_option("code_style") args = [binary, f"--style={style}"] + if sort_includes is not None: + args += ["--sort-includes", sort_includes] + if not shutil.which(binary): if force: raise SfgException( diff --git a/src/pystencilssfg/emission/emitter.py b/src/pystencilssfg/emission/emitter.py index c32b18af351579b98477bf688c0789a394fc3732..c1b6e9c79e09bf1d5604dbd6fca9304190e85272 100644 --- a/src/pystencilssfg/emission/emitter.py +++ b/src/pystencilssfg/emission/emitter.py @@ -1,71 +1,38 @@ from __future__ import annotations -from typing import Sequence -from abc import ABC, abstractmethod -from dataclasses import dataclass -from os import path +from pathlib import Path -from ..context import SfgContext -from ..config import SfgConfig, OutputMode +from ..config import CodeStyle, ClangFormatOptions +from ..ir import SfgSourceFile +from .file_printer import SfgFilePrinter +from .clang_format import invoke_clang_format -@dataclass -class OutputSpec: - """Name and path specification for files output by the code generator. - Filenames are constructed as `<output_directory>/<basename>.<extension>`.""" +class SfgCodeEmitter: + def __init__( + self, + output_directory: Path, + code_style: CodeStyle, + clang_format: ClangFormatOptions, + ): + self._output_dir = output_directory + self._code_style = code_style + self._clang_format_opts = clang_format + self._printer = SfgFilePrinter(code_style) - output_directory: str - """Directory to which the generated files should be written.""" + def emit(self, file: SfgSourceFile): + code = self._printer(file) - basename: str - """Base name for output files.""" + if self._code_style.get_option("includes_sorting_key") is not None: + sort_includes = "Never" + else: + sort_includes = None - header_extension: str - """File extension for generated header file.""" - - impl_extension: str - """File extension for generated implementation file.""" - - def get_header_filename(self): - return f"{self.basename}.{self.header_extension}" - - def get_impl_filename(self): - return f"{self.basename}.{self.impl_extension}" - - def get_header_filepath(self): - return path.join(self.output_directory, self.get_header_filename()) - - def get_impl_filepath(self): - return path.join(self.output_directory, self.get_impl_filename()) - - @staticmethod - def create(config: SfgConfig, basename: str) -> OutputSpec: - output_mode = config.get_option("output_mode") - header_extension = config.extensions.get_option("header") - impl_extension = config.extensions.get_option("impl") - - if impl_extension is None: - match output_mode: - case OutputMode.INLINE: - impl_extension = "ipp" - case OutputMode.STANDALONE: - impl_extension = "cpp" - - return OutputSpec( - config.get_option("output_directory"), - basename, - header_extension, - impl_extension, + code = invoke_clang_format( + code, self._clang_format_opts, sort_includes=sort_includes ) - -class AbstractEmitter(ABC): - @property - @abstractmethod - def output_files(self) -> Sequence[str]: - pass - - @abstractmethod - def write_files(self, ctx: SfgContext): - pass + self._output_dir.mkdir(parents=True, exist_ok=True) + fpath = self._output_dir / file.name + fpath.write_text(code) diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py new file mode 100644 index 0000000000000000000000000000000000000000..8216a7b1923b8b23e19a4fbece8d507aa6260d27 --- /dev/null +++ b/src/pystencilssfg/emission/file_printer.py @@ -0,0 +1,190 @@ +from __future__ import annotations +from textwrap import indent + +from pystencils.backend.emission import CAstPrinter + +from ..ir import ( + SfgSourceFile, + SfgSourceFileType, + SfgNamespaceBlock, + SfgEntityDecl, + SfgEntityDef, + SfgKernelHandle, + SfgFunction, + SfgClassMember, + SfgMethod, + SfgMemberVariable, + SfgConstructor, + SfgClass, + SfgClassBody, + SfgVisibilityBlock, + SfgVisibility, +) +from ..ir.syntax import SfgNamespaceElement, SfgClassBodyElement +from ..config import CodeStyle + + +class SfgFilePrinter: + def __init__(self, code_style: CodeStyle) -> None: + self._code_style = code_style + self._kernel_printer = CAstPrinter( + indent_width=code_style.get_option("indent_width") + ) + + def __call__(self, file: SfgSourceFile) -> str: + code = "" + + if file.file_type == SfgSourceFileType.HEADER: + code += "#pragma once\n\n" + + if file.prelude: + comment = "/**\n" + comment += indent(file.prelude, " * ") + comment += " */\n\n" + + code += comment + + for header in file.includes: + incl = str(header) if header.system_header else f'"{str(header)}"' + code += f"#include {incl}\n" + + if file.includes: + code += "\n" + + # Here begins the actual code + code += "\n\n".join(self.visit(elem) for elem in file.elements) + code += "\n" + + return code + + def visit( + self, elem: SfgNamespaceElement | SfgClassBodyElement, inclass: bool = False + ) -> str: + match elem: + case str(): + return elem + case SfgNamespaceBlock(_, elements, label): + code = f"namespace {label} {{\n" + code += self._code_style.indent( + "\n\n".join(self.visit(e) for e in elements) + ) + code += f"\n}} // namespace {label}" + return code + case SfgEntityDecl(entity): + return self.visit_decl(entity, inclass) + case SfgEntityDef(entity): + return self.visit_defin(entity, inclass) + case SfgClassBody(): + return self.visit_defin(elem, inclass) + case _: + assert False, "illegal code element" + + def visit_decl( + self, + declared_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass, + inclass: bool = False, + ) -> str: + match declared_entity: + case SfgKernelHandle(kernel): + return self._kernel_printer.print_signature(kernel) + ";" + + case SfgFunction(name, _, params) | SfgMethod(name, _, params): + return self._func_signature(declared_entity, inclass) + ";" + + case SfgConstructor(cls, params): + params_str = ", ".join( + f"{param.dtype.c_string()} {param.name}" for param in params + ) + return f"{cls.name}({params_str});" + + case SfgMemberVariable(name, dtype): + return f"{dtype.c_string()} {name};" + + case SfgClass(kwd, name): + return f"{str(kwd)} {name};" + + case _: + assert False, f"unsupported declared entity: {declared_entity}" + + def visit_defin( + self, + defined_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClassBody, + inclass: bool = False, + ) -> str: + match defined_entity: + case SfgKernelHandle(kernel): + return self._kernel_printer(kernel) + + case SfgFunction(name, tree, params) | SfgMethod(name, tree, params): + sig = self._func_signature(defined_entity, inclass) + body = tree.get_code(self._code_style) + body = "\n{\n" + self._code_style.indent(body) + "\n}" + return sig + body + + case SfgConstructor(cls, params): + params_str = ", ".join( + f"{param.dtype.c_string()} {param.name}" for param in params + ) + + code = "" + if not inclass: + code += f"{cls.name}::" + code += f"{cls.name} ({params_str})" + + inits: list[str] = [] + for var, args in defined_entity.initializers: + args_str = ", ".join(str(arg) for arg in args) + inits.append(f"{str(var)}({args_str})") + + if inits: + code += "\n:" + ",\n".join(inits) + + code += "\n{\n" + self._code_style.indent(defined_entity.body) + "\n}" + return code + + case SfgMemberVariable(name, dtype): + code = dtype.c_string() + if not inclass: + code += f" {defined_entity.owning_class.name}::" + code += f" {name}" + if defined_entity.default_init is not None: + args_str = ", ".join(str(expr) for expr in defined_entity.default_init) + code += "{" + args_str + "}" + code += ";" + return code + + case SfgClassBody(cls, vblocks): + code = f"{cls.class_keyword} {cls.name} {{\n" + vblocks_str = [self._visibility_block(b) for b in vblocks] + code += "\n\n".join(vblocks_str) + code += "\n};\n" + return code + + case _: + assert False, f"unsupported defined entity: {defined_entity}" + + def _visibility_block(self, vblock: SfgVisibilityBlock): + prefix = ( + f"{vblock.visibility}:\n" + if vblock.visibility != SfgVisibility.DEFAULT + else "" + ) + elements = [self.visit(elem, inclass=True) for elem in vblock.elements] + return prefix + self._code_style.indent("\n".join(elements)) + + def _func_signature(self, func: SfgFunction | SfgMethod, inclass: bool): + code = "" + if func.inline: + code += "inline " + code += func.return_type.c_string() + " " + params_str = ", ".join( + f"{param.dtype.c_string()} {param.name}" for param in func.parameters + ) + if isinstance(func, SfgMethod) and not inclass: + code += f"{func.owning_class.name}::" + code += f"{func.name}({params_str})" + + if isinstance(func, SfgMethod) and func.const: + code += " const" + + return code diff --git a/src/pystencilssfg/emission/header_impl_pair.py b/src/pystencilssfg/emission/header_impl_pair.py deleted file mode 100644 index 87ff5f55c1484ffd8e5afd1eb53c0349c7c961e8..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/emission/header_impl_pair.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Sequence -from os import path, makedirs - -from ..context import SfgContext -from .printers import SfgHeaderPrinter, SfgImplPrinter -from .clang_format import invoke_clang_format -from ..config import ClangFormatOptions - -from .emitter import AbstractEmitter, OutputSpec - - -class HeaderImplPairEmitter(AbstractEmitter): - """Emits a header-implementation file pair.""" - - def __init__( - self, - output_spec: OutputSpec, - inline_impl: bool = False, - clang_format: ClangFormatOptions | None = None, - ): - """Create a `HeaderImplPairEmitter` from an [SfgOutputSpec][pystencilssfg.configuration.SfgOutputSpec].""" - self._basename = output_spec.basename - self._output_directory = output_spec.output_directory - self._header_filename = output_spec.get_header_filename() - self._impl_filename = output_spec.get_impl_filename() - self._inline_impl = inline_impl - - self._ospec = output_spec - self._clang_format = clang_format - - @property - def output_files(self) -> Sequence[str]: - """The files that will be written by `write_files`.""" - return ( - path.join(self._output_directory, self._header_filename), - path.join(self._output_directory, self._impl_filename), - ) - - def write_files(self, ctx: SfgContext): - """Write the code represented by the given [SfgContext][pystencilssfg.SfgContext] to the files - specified by the output specification.""" - header_printer = SfgHeaderPrinter(ctx, self._ospec, self._inline_impl) - impl_printer = SfgImplPrinter(ctx, self._ospec, self._inline_impl) - - header = header_printer.get_code() - impl = impl_printer.get_code() - - if self._clang_format is not None: - header = invoke_clang_format(header, self._clang_format) - impl = invoke_clang_format(impl, self._clang_format) - - makedirs(self._output_directory, exist_ok=True) - - with open(self._ospec.get_header_filepath(), "w") as headerfile: - headerfile.write(header) - - with open(self._ospec.get_impl_filepath(), "w") as cppfile: - cppfile.write(impl) diff --git a/src/pystencilssfg/emission/header_only.py b/src/pystencilssfg/emission/header_only.py deleted file mode 100644 index 7d026da7aea8b32495bba6c2298adfb35ff4fb00..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/emission/header_only.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Sequence -from os import path, makedirs - -from ..context import SfgContext -from .printers import SfgHeaderPrinter -from ..config import ClangFormatOptions -from .clang_format import invoke_clang_format - -from .emitter import AbstractEmitter, OutputSpec - - -class HeaderOnlyEmitter(AbstractEmitter): - def __init__( - self, output_spec: OutputSpec, clang_format: ClangFormatOptions | None = None - ): - """Create a `HeaderImplPairEmitter` from an [SfgOutputSpec][pystencilssfg.configuration.SfgOutputSpec].""" - self._basename = output_spec.basename - self._output_directory = output_spec.output_directory - self._header_filename = output_spec.get_header_filename() - - self._ospec = output_spec - self._clang_format = clang_format - - @property - def output_files(self) -> Sequence[str]: - """The files that will be written by `write_files`.""" - return (path.join(self._output_directory, self._header_filename),) - - def write_files(self, ctx: SfgContext): - header_printer = SfgHeaderPrinter(ctx, self._ospec) - header = header_printer.get_code() - if self._clang_format is not None: - header = invoke_clang_format(header, self._clang_format) - - makedirs(self._output_directory, exist_ok=True) - - with open(self._ospec.get_header_filepath(), "w") as headerfile: - headerfile.write(header) diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py deleted file mode 100644 index 9d7c97e7ce732066c91eda3e3cbf887dcb552f77..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/emission/printers.py +++ /dev/null @@ -1,265 +0,0 @@ -from __future__ import annotations - -from textwrap import indent -from itertools import chain, repeat, cycle - -from pystencils.codegen import Kernel -from pystencils.backend.emission import emit_code - -from ..context import SfgContext -from ..visitors import visitor -from ..exceptions import SfgException - -from ..ir.source_components import ( - SfgEmptyLines, - SfgHeaderInclude, - SfgKernelNamespace, - SfgFunction, - SfgClass, - SfgInClassDefinition, - SfgConstructor, - SfgMemberVariable, - SfgMethod, - SfgVisibility, - SfgVisibilityBlock, -) - -from .emitter import OutputSpec - - -def interleave(*iters): - try: - for iter in cycle(iters): - yield next(iter) - except StopIteration: - pass - - -class SfgGeneralPrinter: - @visitor - def visit(self, obj: object) -> str: - raise SfgException(f"Can't print object of type {type(obj)}") - - @visit.case(SfgEmptyLines) - def emptylines(self, el: SfgEmptyLines) -> str: - return "\n" * el.lines - - @visit.case(str) - def string(self, s: str) -> str: - return s - - @visit.case(SfgHeaderInclude) - def include(self, incl: SfgHeaderInclude) -> str: - if incl.system_header: - return f"#include <{incl.file}>" - else: - return f'#include "{incl.file}"' - - def prelude(self, ctx: SfgContext) -> str: - if ctx.prelude_comment: - return ( - "/*\n" - + indent(ctx.prelude_comment, "* ", predicate=lambda _: True) - + "*/\n" - ) - else: - return "" - - def param_list(self, func: SfgFunction) -> str: - params = sorted(list(func.parameters), key=lambda p: p.name) - return ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) - - -class SfgHeaderPrinter(SfgGeneralPrinter): - def __init__( - self, ctx: SfgContext, output_spec: OutputSpec, inline_impl: bool = False - ): - self._output_spec = output_spec - self._ctx = ctx - self._inline_impl = inline_impl - - def get_code(self) -> str: - return self.visit(self._ctx) - - @visitor - def visit(self, obj: object) -> str: - return super().visit(obj) - - @visit.case(SfgContext) - def frame(self, ctx: SfgContext) -> str: - code = super().prelude(ctx) - - code += "\n#pragma once\n\n" - - includes = filter(lambda incl: not incl.private, ctx.includes()) - code += "\n".join(self.visit(incl) for incl in includes) - code += "\n\n" - - fq_namespace = ctx.fully_qualified_namespace - if fq_namespace is not None: - code += f"namespace {fq_namespace} {{\n\n" - - parts = interleave(ctx.declarations_ordered(), repeat(SfgEmptyLines(1))) - - code += "\n".join(self.visit(p) for p in parts) - - if fq_namespace is not None: - code += f"}} // namespace {fq_namespace}\n" - - if self._inline_impl: - code += f'#include "{self._output_spec.get_impl_filename()}"\n' - - return code - - @visit.case(SfgFunction) - def function(self, func: SfgFunction): - params = sorted(list(func.parameters), key=lambda p: p.name) - param_list = ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) - return f"{func.return_type.c_string()} {func.name} ( {param_list} );" - - @visit.case(SfgClass) - def sfg_class(self, cls: SfgClass): - code = f"{cls.class_keyword} {cls.class_name} \n" - - if cls.base_classes: - code += f" : {','.join(cls.base_classes)}\n" - - code += "{\n" - - for block in cls.visibility_blocks(): - code += self.visit(block) + "\n" - - code += "};\n" - - return code - - @visit.case(SfgVisibilityBlock) - def vis_block(self, block: SfgVisibilityBlock) -> str: - code = "" - if block.visibility != SfgVisibility.DEFAULT: - code += f"{block.visibility}:\n" - code += self._ctx.codestyle.indent( - "\n".join(self.visit(m) for m in block.members()) - ) - return code - - @visit.case(SfgInClassDefinition) - def sfg_inclassdef(self, definition: SfgInClassDefinition): - return definition.text - - @visit.case(SfgConstructor) - def sfg_constructor(self, constr: SfgConstructor): - code = f"{constr.owning_class.class_name} (" - code += ", ".join(f"{param.dtype.c_string()} {param.name}" for param in constr.parameters) - code += ")\n" - if constr.initializers: - code += " : " + ", ".join(constr.initializers) + "\n" - if constr.body: - code += "{\n" + self._ctx.codestyle.indent(constr.body) + "\n}\n" - else: - code += "{ }\n" - return code - - @visit.case(SfgMemberVariable) - def sfg_member_var(self, var: SfgMemberVariable): - return f"{var.dtype.c_string()} {var.name};" - - @visit.case(SfgMethod) - def sfg_method(self, method: SfgMethod): - code = f"{method.return_type.c_string()} {method.name} ({self.param_list(method)})" - code += "const" if method.const else "" - if method.inline: - code += ( - " {\n" - + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) - + "}\n" - ) - else: - code += ";" - return code - - -class SfgImplPrinter(SfgGeneralPrinter): - def __init__( - self, ctx: SfgContext, output_spec: OutputSpec, inline_impl: bool = False - ): - self._output_spec = output_spec - self._ctx = ctx - self._inline_impl = inline_impl - - def get_code(self) -> str: - return self.visit(self._ctx) - - @visitor - def visit(self, obj: object) -> str: - return super().visit(obj) - - @visit.case(SfgContext) - def frame(self, ctx: SfgContext) -> str: - code = super().prelude(ctx) - - if not self._inline_impl: - code += f'\n#include "{self._output_spec.get_header_filename()}"\n\n' - - includes = filter(lambda incl: incl.private, ctx.includes()) - code += "\n".join(self.visit(incl) for incl in includes) - - code += "\n\n#define FUNC_PREFIX inline\n\n" - - fq_namespace = ctx.fully_qualified_namespace - if fq_namespace is not None: - code += f"namespace {fq_namespace} {{\n\n" - - parts = interleave( - chain( - ctx.kernel_namespaces(), - ctx.functions(), - ctx.classes(), - ), - repeat(SfgEmptyLines(1)), - ) - - code += "\n".join(self.visit(p) for p in parts) - - if fq_namespace is not None: - code += f"}} // namespace {fq_namespace}\n" - - return code - - @visit.case(SfgKernelNamespace) - def kernel_namespace(self, kns: SfgKernelNamespace) -> str: - code = f"namespace {kns.name} {{\n\n" - code += "\n\n".join(self.visit(ast) for ast in kns.kernel_functions) - code += f"\n}} // namespace {kns.name}\n" - return code - - @visit.case(Kernel) - def kernel(self, kfunc: Kernel) -> str: - return emit_code(kfunc) - - @visit.case(SfgFunction) - def function(self, func: SfgFunction) -> str: - inline_prefix = "inline " if self._inline_impl else "" - code = ( - f"{inline_prefix} {func.return_type.c_string()} {func.name} ({self.param_list(func)})" - ) - code += ( - "{\n" + self._ctx.codestyle.indent(func.tree.get_code(self._ctx)) + "}\n" - ) - return code - - @visit.case(SfgClass) - def sfg_class(self, cls: SfgClass) -> str: - methods = filter(lambda m: not m.inline, cls.methods()) - return "\n".join(self.visit(m) for m in methods) - - @visit.case(SfgMethod) - def sfg_method(self, method: SfgMethod) -> str: - inline_prefix = "inline " if self._inline_impl else "" - const_qual = "const" if method.const else "" - code = f"{inline_prefix}{method.return_type} {method.owning_class.class_name}::{method.name}" - code += f"({self.param_list(method)}) {const_qual}" - code += ( - " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n" - ) - return code diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index 88dbc9be2e215b1fdce5833ef18eac6eab336d74..48f9c08d8754d9a6626109b1c037983a923b33f2 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -8,6 +8,7 @@ from pystencils import Target from pystencilssfg.composer.basic_composer import SequencerArg +from ..config import CodeStyle from ..exceptions import SfgException from ..context import SfgContext from ..composer import ( @@ -17,8 +18,8 @@ from ..composer import ( SfgComposerMixIn, make_sequence, ) -from ..ir.source_components import SfgKernelHandle from ..ir import ( + SfgKernelHandle, SfgCallTreeNode, SfgCallTreeLeaf, SfgKernelCallNode, @@ -93,11 +94,11 @@ class SyclHandler(AugExpr): if isinstance(range, _VarLike): range = asvar(range) - def check_kernel(kernel: SfgKernelHandle): - kfunc = kernel.get_kernel_function() + def check_kernel(khandle: SfgKernelHandle): + kfunc = khandle.kernel if kfunc.target != Target.SYCL: raise SfgException( - f"Kernel given to `parallel_for` is no SYCL kernel: {kernel.kernel_name}" + f"Kernel given to `parallel_for` is no SYCL kernel: {khandle.fqname}" ) id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>") @@ -144,7 +145,7 @@ class SyclGroup(AugExpr): self._ctx = ctx def parallel_for_work_item( - self, range: VarLike | Sequence[int], kernel: SfgKernelHandle + self, range: VarLike | Sequence[int], khandle: SfgKernelHandle ): """Generate a ``parallel_for_work_item` kernel invocation on this group.` @@ -155,10 +156,10 @@ class SyclGroup(AugExpr): if isinstance(range, _VarLike): range = asvar(range) - kfunc = kernel.get_kernel_function() + kfunc = khandle.kernel if kfunc.target != Target.SYCL: raise SfgException( - f"Kernel given to `parallel_for` is no SYCL kernel: {kernel.kernel_name}" + f"Kernel given to `parallel_for` is no SYCL kernel: {khandle.fqname}" ) id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>") @@ -169,13 +170,13 @@ class SyclGroup(AugExpr): and id_regex.search(param.dtype.c_string()) is not None ) - id_param = list(filter(filter_id, kernel.scalar_parameters))[0] + id_param = list(filter(filter_id, khandle.scalar_parameters))[0] h_item = SfgVar("item", PsCustomType("sycl::h_item< 3 >")) comp = SfgComposer(self._ctx) tree = comp.seq( comp.set_param(id_param, AugExpr.format("{}.get_local_id()", h_item)), - SfgKernelCallNode(kernel), + SfgKernelCallNode(khandle), ) kernel_lambda = SfgLambda(("=",), (h_item,), tree, None) @@ -229,11 +230,11 @@ class SfgLambda: def required_parameters(self) -> set[SfgVar]: return self._required_params - def get_code(self, ctx: SfgContext): + def get_code(self, cstyle: CodeStyle): captures = ", ".join(self._captures) params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params) - body = self._tree.get_code(ctx) - body = ctx.codestyle.indent(body) + body = self._tree.get_code(cstyle) + body = cstyle.indent(body) rtype = ( f"-> {self._return_type.c_string()} " if self._return_type is not None @@ -300,13 +301,13 @@ class SyclKernelInvoke(SfgCallTreeLeaf): def depends(self) -> set[SfgVar]: return self._required_params - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: if isinstance(self._range, SfgVar): range_code = self._range.name else: range_code = "{ " + ", ".join(str(r) for r in self._range) + " }" - kernel_code = self._lambda.get_code(ctx) + kernel_code = self._lambda.get_code(cstyle) invoker = str(self._invoker) method = self._invoke_type.method diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index eed06b0d684850e462b86aa11b53bc4ca1586185..f3f67a02f4da7a44ae324ddd41f5585645cadb58 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -1,11 +1,17 @@ -import os -from os import path - -from .config import SfgConfig, CommandLineParameters, OutputMode, GLOBAL_NAMESPACE +from pathlib import Path + +from typing import Callable, Any +from .config import ( + SfgConfig, + CommandLineParameters, + OutputMode, + _GlobalNamespace, +) from .context import SfgContext from .composer import SfgComposer -from .emission import AbstractEmitter, OutputSpec +from .emission import SfgCodeEmitter from .exceptions import SfgException +from .lang import HeaderFile class SourceFileGenerator: @@ -27,7 +33,9 @@ class SourceFileGenerator: """ def __init__( - self, sfg_config: SfgConfig | None = None, keep_unknown_argv: bool = False + self, + sfg_config: SfgConfig | None = None, + keep_unknown_argv: bool = False, ): if sfg_config and not isinstance(sfg_config, SfgConfig): raise TypeError("sfg_config is not an SfgConfiguration.") @@ -41,9 +49,9 @@ class SourceFileGenerator: "without a valid entry point, such as a REPL or a multiprocessing fork." ) - scriptpath = __main__.__file__ - scriptname = path.split(scriptpath)[1] - basename = path.splitext(scriptname)[0] + scriptpath = Path(__main__.__file__) + scriptname = scriptpath.name + basename = scriptname.rsplit(".")[0] from argparse import ArgumentParser @@ -67,47 +75,76 @@ class SourceFileGenerator: cli_params.find_conflicts(sfg_config) config.override(sfg_config) + self._output_mode: OutputMode = config.get_option("output_mode") + self._output_dir: Path = config.get_option("output_directory") + + output_files = config._get_output_files(basename) + + from .ir import SfgSourceFile, SfgSourceFileType + + self._header_file = SfgSourceFile( + output_files[0].name, SfgSourceFileType.HEADER + ) + self._impl_file: SfgSourceFile | None + + match self._output_mode: + case OutputMode.HEADER_ONLY: + self._impl_file = None + case OutputMode.STANDALONE: + self._impl_file = SfgSourceFile( + output_files[1].name, SfgSourceFileType.TRANSLATION_UNIT + ) + self._impl_file.includes.append( + HeaderFile.parse(self._header_file.name) + ) + case OutputMode.INLINE: + self._impl_file = SfgSourceFile( + output_files[1].name, SfgSourceFileType.HEADER + ) + + # TODO: Find a way to not hard-code the restrict qualifier in pystencils + self._header_file.elements.append("#define RESTRICT __restrict__") + + outer_namespace: str | _GlobalNamespace = config.get_option("outer_namespace") + + namespace: str | None + if isinstance(outer_namespace, _GlobalNamespace): + namespace = None + else: + namespace = outer_namespace + self._context = SfgContext( - None if config.outer_namespace is GLOBAL_NAMESPACE else config.outer_namespace, # type: ignore + self._header_file, + self._impl_file, + namespace, config.codestyle, argv=script_args, project_info=cli_params.get_project_info(), ) - from .lang import HeaderFile - from .ir import SfgHeaderInclude - - self._context.add_include(SfgHeaderInclude(HeaderFile("cstdint", system_header=True))) - self._context.add_definition("#define RESTRICT __restrict__") + self._emitter = SfgCodeEmitter( + self._output_dir, config.codestyle, config.clang_format + ) - output_mode = config.get_option("output_mode") - output_spec = OutputSpec.create(config, basename) + sort_key = config.codestyle.get_option("includes_sorting_key") + if sort_key is None: - self._emitter: AbstractEmitter - match output_mode: - case OutputMode.HEADER_ONLY: - from .emission import HeaderOnlyEmitter + def default_key(h: HeaderFile): + return str(h) - self._emitter = HeaderOnlyEmitter( - output_spec, clang_format=config.clang_format - ) - case OutputMode.INLINE: - from .emission import HeaderImplPairEmitter + sort_key = default_key - self._emitter = HeaderImplPairEmitter( - output_spec, inline_impl=True, clang_format=config.clang_format - ) - case OutputMode.STANDALONE: - from .emission import HeaderImplPairEmitter - - self._emitter = HeaderImplPairEmitter( - output_spec, clang_format=config.clang_format - ) + self._include_sort_key: Callable[[HeaderFile], Any] = sort_key def clean_files(self): - for file in self._emitter.output_files: - if path.exists(file): - os.remove(file) + header_path = self._output_dir / self._header_file.name + if header_path.exists(): + header_path.unlink() + + if self._impl_file is not None: + impl_path = self._output_dir / self._impl_file.name + if impl_path.exists(): + impl_path.unlink() def __enter__(self) -> SfgComposer: self.clean_files() @@ -115,9 +152,27 @@ class SourceFileGenerator: def __exit__(self, exc_type, exc_value, traceback): if exc_type is None: - # Collect header files for inclusion - from .ir import SfgHeaderInclude, collect_includes - for header in collect_includes(self._context): - self._context.add_include(SfgHeaderInclude(header)) + if self._output_mode == OutputMode.INLINE: + assert self._impl_file is not None + self._header_file.elements.append(f'#include "{self._impl_file.name}"') + + from .ir import collect_includes + + header_includes = collect_includes(self._header_file) + self._header_file.includes = list( + set(self._header_file.includes) | header_includes + ) + self._header_file.includes.sort(key=self._include_sort_key) + + if self._impl_file is not None: + impl_includes = collect_includes(self._impl_file) + # If some header is already included by the generated header file, do not duplicate that inclusion + impl_includes -= header_includes + self._impl_file.includes = list( + set(self._impl_file.includes) | impl_includes + ) + self._impl_file.includes.sort(key=self._include_sort_key) - self._emitter.write_files(self._context) + self._emitter.emit(self._header_file) + if self._impl_file is not None: + self._emitter.emit(self._impl_file) diff --git a/src/pystencilssfg/ir/__init__.py b/src/pystencilssfg/ir/__init__.py index 8eee39cfc2e5fa6e37f677dd780aa35bb9a796b6..8f03fed0d4c2467377cdaab6cf100a13f7ded9fb 100644 --- a/src/pystencilssfg/ir/__init__.py +++ b/src/pystencilssfg/ir/__init__.py @@ -14,23 +14,32 @@ from .call_tree import ( SfgSwitch, ) -from .source_components import ( - SfgHeaderInclude, - SfgEmptyLines, +from .entities import ( + SfgCodeEntity, + SfgNamespace, + SfgGlobalNamespace, SfgKernelNamespace, SfgKernelHandle, - SfgKernelParamVar, SfgFunction, SfgVisibility, SfgClassKeyword, SfgClassMember, - SfgVisibilityBlock, - SfgInClassDefinition, SfgMemberVariable, SfgMethod, SfgConstructor, SfgClass, ) + +from .syntax import ( + SfgEntityDecl, + SfgEntityDef, + SfgVisibilityBlock, + SfgNamespaceBlock, + SfgClassBody, + SfgSourceFileType, + SfgSourceFile, +) + from .analysis import collect_includes __all__ = [ @@ -47,20 +56,25 @@ __all__ = [ "SfgBranch", "SfgSwitchCase", "SfgSwitch", - "SfgHeaderInclude", - "SfgEmptyLines", + "SfgCodeEntity", + "SfgNamespace", + "SfgGlobalNamespace", "SfgKernelNamespace", "SfgKernelHandle", - "SfgKernelParamVar", "SfgFunction", "SfgVisibility", "SfgClassKeyword", "SfgClassMember", - "SfgVisibilityBlock", - "SfgInClassDefinition", "SfgMemberVariable", "SfgMethod", "SfgConstructor", "SfgClass", - "collect_includes" + "SfgEntityDecl", + "SfgEntityDef", + "SfgVisibilityBlock", + "SfgNamespaceBlock", + "SfgClassBody", + "SfgSourceFileType", + "SfgSourceFile", + "collect_includes", ] diff --git a/src/pystencilssfg/ir/analysis.py b/src/pystencilssfg/ir/analysis.py index 0b42594033da16efbde3ab8da3433dfcad03a097..4e43eb92c7a2457c5e60f0743c9ac5d3809f87bc 100644 --- a/src/pystencilssfg/ir/analysis.py +++ b/src/pystencilssfg/ir/analysis.py @@ -1,67 +1,97 @@ from __future__ import annotations -from typing import Any -from functools import reduce - -from ..exceptions import SfgException from ..lang import HeaderFile, includes +from .syntax import ( + SfgSourceFile, + SfgNamespaceElement, + SfgClassBodyElement, + SfgNamespaceBlock, + SfgEntityDecl, + SfgEntityDef, + SfgClassBody, + SfgVisibilityBlock, +) -def collect_includes(obj: Any) -> set[HeaderFile]: - from ..context import SfgContext +def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: from .call_tree import SfgCallTreeNode - from .source_components import ( + from .entities import ( + SfgCodeEntity, + SfgKernelHandle, SfgFunction, - SfgClass, + SfgMethod, + SfgClassMember, SfgConstructor, SfgMemberVariable, - SfgInClassDefinition, ) - match obj: - case SfgContext(): - headers = set() - for func in obj.functions(): - headers |= collect_includes(func) - - for cls in obj.classes(): - headers |= collect_includes(cls) - - return headers - - case SfgCallTreeNode(): - return reduce( - lambda accu, child: accu | collect_includes(child), - obj.children, - obj.required_includes, - ) - - case SfgFunction(_, tree, parameters): - param_headers: set[HeaderFile] = reduce( - set.union, (includes(p) for p in parameters), set() - ) - return param_headers | collect_includes(tree) - - case SfgClass(): - return reduce( - lambda accu, member: accu | (collect_includes(member)), - obj.members(), - set(), - ) - - case SfgConstructor(parameters): - param_headers = reduce( - set.union, (includes(p) for p in parameters), set() - ) - return param_headers - - case SfgMemberVariable(): - return includes(obj) - - case SfgInClassDefinition(): - return set() - - case _: - raise SfgException( - f"Can't collect includes from object of type {type(obj)}" - ) + def visit_decl(entity: SfgCodeEntity | SfgClassMember) -> set[HeaderFile]: + match entity: + case ( + SfgKernelHandle(_, parameters) + | SfgFunction(_, _, parameters) + | SfgMethod(_, _, parameters) + | SfgConstructor(_, parameters, _, _) + ): + incls: set[HeaderFile] = set().union(*(includes(p) for p in parameters)) + if isinstance(entity, (SfgFunction, SfgMethod)): + incls |= includes(entity.return_type) + return incls + + case SfgMemberVariable(): + return includes(entity) + + case _: + assert False, "unexpected entity" + + def walk_syntax( + obj: ( + SfgNamespaceElement + | SfgClassBodyElement + | SfgVisibilityBlock + | SfgCallTreeNode + ), + ) -> set[HeaderFile]: + match obj: + case str(): + return set() + + case SfgCallTreeNode(): + return obj.required_includes.union( + *(walk_syntax(child) for child in obj.children), + ) + + case SfgEntityDecl(entity): + return visit_decl(entity) + + case SfgEntityDef(entity): + match entity: + case SfgKernelHandle(kernel, _): + return ( + set(HeaderFile.parse(h) for h in kernel.required_headers) + | {HeaderFile.parse("<cstdint>")} + | visit_decl(entity) + ) + + case SfgFunction(_, tree, _) | SfgMethod(_, tree, _): + return walk_syntax(tree) | visit_decl(entity) + + case SfgConstructor(): + return visit_decl(entity) + + case SfgMemberVariable(): + return includes(entity) + + case _: + assert False, "unexpected entity" + + case SfgNamespaceBlock(_, elements) | SfgVisibilityBlock(_, elements): + return set().union(*(walk_syntax(elem) for elem in elements)) # type: ignore + + case SfgClassBody(_, vblocks): + return set().union(*(walk_syntax(vb) for vb in vblocks)) + + case _: + assert False, "unexpected syntax element" + + return set().union(*(walk_syntax(elem) for elem in file.elements)) diff --git a/src/pystencilssfg/ir/call_tree.py b/src/pystencilssfg/ir/call_tree.py index a5d2c5a35b1795817305515b74797c2bf3f2b91b..4cee2f526c3a26e34f6b122a3dd1d8a15dd11563 100644 --- a/src/pystencilssfg/ir/call_tree.py +++ b/src/pystencilssfg/ir/call_tree.py @@ -3,11 +3,11 @@ from typing import TYPE_CHECKING, Sequence, Iterable, NewType from abc import ABC, abstractmethod -from .source_components import SfgKernelHandle +from .entities import SfgKernelHandle from ..lang import SfgVar, HeaderFile if TYPE_CHECKING: - from ..context import SfgContext + from ..config import CodeStyle class SfgCallTreeNode(ABC): @@ -35,7 +35,7 @@ class SfgCallTreeNode(ABC): """This node's children""" @abstractmethod - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: """Returns the code of this node. By convention, the code block emitted by this function should not contain a trailing newline. @@ -75,7 +75,7 @@ class SfgEmptyNode(SfgCallTreeLeaf): def __init__(self): super().__init__() - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: return "" @@ -122,7 +122,7 @@ class SfgStatements(SfgCallTreeLeaf): def code_string(self) -> str: return self._code_string - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: return self._code_string @@ -167,8 +167,8 @@ class SfgSequence(SfgCallTreeNode): def __setitem__(self, idx: int, c: SfgCallTreeNode): self._children[idx] = c - def get_code(self, ctx: SfgContext) -> str: - return "\n".join(c.get_code(ctx) for c in self._children) + def get_code(self, cstyle: CodeStyle) -> str: + return "\n".join(c.get_code(cstyle) for c in self._children) class SfgBlock(SfgCallTreeNode): @@ -184,8 +184,8 @@ class SfgBlock(SfgCallTreeNode): def children(self) -> Sequence[SfgCallTreeNode]: return (self._seq,) - def get_code(self, ctx: SfgContext) -> str: - seq_code = ctx.codestyle.indent(self._seq.get_code(ctx)) + def get_code(self, cstyle: CodeStyle) -> str: + seq_code = cstyle.indent(self._seq.get_code(cstyle)) return "{\n" + seq_code + "\n}" @@ -208,9 +208,9 @@ class SfgKernelCallNode(SfgCallTreeLeaf): def depends(self) -> set[SfgVar]: return set(self._kernel_handle.parameters) - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: ast_params = self._kernel_handle.parameters - fnc_name = self._kernel_handle.fully_qualified_name + fnc_name = self._kernel_handle.fqname call_parameters = ", ".join([p.name for p in ast_params]) return f"{fnc_name}({call_parameters});" @@ -228,8 +228,8 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf): from pystencils import Target from pystencils.codegen import GpuKernel - func = kernel_handle.get_kernel_function() - if not (isinstance(func, GpuKernel) and func.target == Target.CUDA): + kernel = kernel_handle.kernel + if not (isinstance(kernel, GpuKernel) and kernel.target == Target.CUDA): raise ValueError( "An `SfgCudaKernelInvocation` node can only call a CUDA kernel." ) @@ -245,9 +245,9 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf): def depends(self) -> set[SfgVar]: return set(self._kernel_handle.parameters) | self._depends - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: ast_params = self._kernel_handle.parameters - fnc_name = self._kernel_handle.fully_qualified_name + fnc_name = self._kernel_handle.fqname call_parameters = ", ".join([p.name for p in ast_params]) grid_args = [self._num_blocks, self._threads_per_block] @@ -289,14 +289,14 @@ class SfgBranch(SfgCallTreeNode): self._branch_true, ) + ((self.branch_false,) if self.branch_false is not None else ()) - def get_code(self, ctx: SfgContext) -> str: - code = f"if({self.condition.get_code(ctx)}) {{\n" - code += ctx.codestyle.indent(self.branch_true.get_code(ctx)) + def get_code(self, cstyle: CodeStyle) -> str: + code = f"if({self.condition.get_code(cstyle)}) {{\n" + code += cstyle.indent(self.branch_true.get_code(cstyle)) code += "\n}" if self.branch_false is not None: code += "else {\n" - code += ctx.codestyle.indent(self.branch_false.get_code(ctx)) + code += cstyle.indent(self.branch_false.get_code(cstyle)) code += "\n}" return code @@ -327,13 +327,13 @@ class SfgSwitchCase(SfgCallTreeNode): def is_default(self) -> bool: return self._label == SfgSwitchCase.Default - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: code = "" if self._label == SfgSwitchCase.Default: code += "default: {\n" else: code += f"case {self._label}: {{\n" - code += ctx.codestyle.indent(self.body.get_code(ctx)) + code += cstyle.indent(self.body.get_code(cstyle)) code += "\n}" return code @@ -403,8 +403,8 @@ class SfgSwitch(SfgCallTreeNode): else: self._children[idx] = c - def get_code(self, ctx: SfgContext) -> str: - code = f"switch({self._switch_arg.get_code(ctx)}) {{\n" - code += "\n".join(c.get_code(ctx) for c in self._cases) + def get_code(self, cstyle: CodeStyle) -> str: + code = f"switch({self._switch_arg.get_code(cstyle)}) {{\n" + code += "\n".join(c.get_code(cstyle) for c in self._cases) code += "}" return code diff --git a/src/pystencilssfg/ir/entities.py b/src/pystencilssfg/ir/entities.py new file mode 100644 index 0000000000000000000000000000000000000000..62ae1eb7611065c7b1a12b77d6894143aed341dd --- /dev/null +++ b/src/pystencilssfg/ir/entities.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +from abc import ABC +from enum import Enum, auto +from typing import ( + TYPE_CHECKING, + Sequence, + Generator, +) +from itertools import chain + +from pystencils import Field +from pystencils.codegen import Kernel +from pystencils.types import PsType, PsCustomType + +from ..lang import SfgVar, SfgKernelParamVar, void, ExprLike +from ..exceptions import SfgException + +if TYPE_CHECKING: + from . import SfgCallTreeNode + + +# ========================================================================================================= +# +# SEMANTICAL ENTITIES +# +# These classes model *code entities*, which represent *semantic components* of the generated files. +# +# ========================================================================================================= + + +class SfgCodeEntity: + """Base class for code entities. + + Each code entity has a name and an optional enclosing namespace. + """ + + def __init__(self, name: str, parent_namespace: SfgNamespace) -> None: + self._name = name + self._namespace: SfgNamespace = parent_namespace + + @property + def name(self) -> str: + """Name of this entity""" + return self._name + + @property + def fqname(self) -> str: + """Fully qualified name of this entity""" + if not isinstance(self._namespace, SfgGlobalNamespace): + return self._namespace.fqname + "::" + self._name + else: + return self._name + + @property + def parent_namespace(self) -> SfgNamespace | None: + """Parent namespace of this entity""" + return self._namespace + + +class SfgNamespace(SfgCodeEntity): + """A C++ namespace. + + Each namespace has a `name` and a `parent`; its fully qualified name is given as + ``<parent.name>::<name>``. + + Args: + name: Local name of this namespace + parent: Parent namespace enclosing this namespace + """ + + def __init__(self, name: str, parent_namespace: SfgNamespace) -> None: + super().__init__(name, parent_namespace) + + self._entities: dict[str, SfgCodeEntity] = dict() + + def get_entity(self, qual_name: str) -> SfgCodeEntity | None: + """Find an entity with the given qualified name within this namespace. + + If `qual_name` contains any qualifying delimiters ``::``, + each component but the last is interpreted as a namespace. + """ + tokens = qual_name.split("::", 1) + match tokens: + case [entity_name]: + return self._entities.get(entity_name, None) + case [nspace, remaining_qualname]: + sub_nspace = self._entities.get(nspace, None) + if sub_nspace is not None: + if not isinstance(sub_nspace, SfgNamespace): + raise KeyError( + f"Unable to find entity {qual_name} in namespace {self._name}: " + f"Entity {nspace} is not a namespace." + ) + return sub_nspace.get_entity(remaining_qualname) + else: + return None + case _: + assert False, "unreachable code" + + def add_entity(self, entity: SfgCodeEntity): + if entity.name in self._entities: + raise ValueError( + f"Another entity with the name {entity.fqname} already exists" + ) + self._entities[entity.name] = entity + + def get_child_namespace(self, qual_name: str): + if not qual_name: + raise ValueError("Anonymous namespaces are not supported") + + # Find the namespace by qualified lookup ... + namespace = self.get_entity(qual_name) + if namespace is not None: + if not type(namespace) is SfgNamespace: + raise ValueError(f"Entity {qual_name} exists, but is not a namespace") + else: + # ... or create it + tokens = qual_name.split("::") + namespace = self + for tok in tokens: + namespace = SfgNamespace(tok, namespace) + + return namespace + + +class SfgGlobalNamespace(SfgNamespace): + """The C++ global namespace.""" + + def __init__(self) -> None: + super().__init__("", self) + + @property + def fqname(self) -> str: + return "" + + +class SfgKernelHandle(SfgCodeEntity): + """Handle to a pystencils kernel.""" + + __match_args__ = ("kernel", "parameters") + + def __init__(self, name: str, namespace: SfgKernelNamespace, kernel: Kernel): + super().__init__(name, namespace) + + self._kernel = kernel + self._parameters = [SfgKernelParamVar(p) for p in kernel.parameters] + + self._scalar_params: set[SfgVar] = set() + self._fields: set[Field] = set() + + for param in self._parameters: + if param.wrapped.is_field_parameter: + self._fields |= set(param.wrapped.fields) + else: + self._scalar_params.add(param) + + @property + def parameters(self) -> Sequence[SfgKernelParamVar]: + """Parameters to this kernel""" + return self._parameters + + @property + def scalar_parameters(self) -> set[SfgVar]: + """Scalar parameters to this kernel""" + return self._scalar_params + + @property + def fields(self): + """Fields accessed by this kernel""" + return self._fields + + @property + def kernel(self) -> Kernel: + """Underlying pystencils kernel object""" + return self._kernel + + +class SfgKernelNamespace(SfgNamespace): + """A namespace grouping together a number of kernels.""" + + def __init__(self, name: str, parent: SfgNamespace): + super().__init__(name, parent) + self._kernels: dict[str, SfgKernelHandle] = dict() + + @property + def name(self): + return self._name + + @property + def kernels(self) -> tuple[SfgKernelHandle, ...]: + return tuple(self._kernels.values()) + + def find_kernel(self, name: str) -> SfgKernelHandle | None: + return self._kernels.get(name, None) + + def add_kernel(self, kernel: SfgKernelHandle): + if kernel.name in self._kernels: + raise ValueError( + f"Duplicate kernels: A kernel called {kernel.name} already exists " + f"in namespace {self.fqname}" + ) + self._kernels[kernel.name] = kernel + + +class SfgFunction(SfgCodeEntity): + """A free function.""" + + __match_args__ = ("name", "tree", "parameters", "return_type") + + def __init__( + self, + name: str, + namespace: SfgNamespace, + tree: SfgCallTreeNode, + return_type: PsType = void, + inline: bool = False, + ): + super().__init__(name, namespace) + + self._tree = tree + self._return_type = return_type + self._inline = inline + + self._parameters: tuple[SfgVar, ...] + + from .postprocessing import CallTreePostProcessing + + param_collector = CallTreePostProcessing() + self._parameters = tuple( + sorted(param_collector(self._tree).function_params, key=lambda p: p.name) + ) + + @property + def parameters(self) -> tuple[SfgVar, ...]: + return self._parameters + + @property + def tree(self) -> SfgCallTreeNode: + return self._tree + + @property + def return_type(self) -> PsType: + return self._return_type + + @property + def inline(self) -> bool: + return self._inline + + +class SfgVisibility(Enum): + """Visibility qualifiers of C++""" + + DEFAULT = auto() + PRIVATE = auto() + PROTECTED = auto() + PUBLIC = auto() + + def __str__(self) -> str: + match self: + case SfgVisibility.DEFAULT: + return "" + case SfgVisibility.PRIVATE: + return "private" + case SfgVisibility.PROTECTED: + return "protected" + case SfgVisibility.PUBLIC: + return "public" + + +class SfgClassKeyword(Enum): + """Class keywords of C++""" + + STRUCT = auto() + CLASS = auto() + + def __str__(self) -> str: + match self: + case SfgClassKeyword.STRUCT: + return "struct" + case SfgClassKeyword.CLASS: + return "class" + + +class SfgClassMember(ABC): + """Base class for class member entities""" + + def __init__(self, cls: SfgClass) -> None: + self._cls: SfgClass = cls + self._visibility: SfgVisibility | None = None + + @property + def owning_class(self) -> SfgClass: + if self._cls is None: + raise SfgException(f"{self} is not bound to a class.") + return self._cls + + @property + def visibility(self) -> SfgVisibility: + if self._visibility is None: + raise SfgException( + f"{self} is not bound to a class and therefore has no visibility." + ) + return self._visibility + + +class SfgMemberVariable(SfgVar, SfgClassMember): + """Variable that is a field of a class""" + + def __init__( + self, + name: str, + dtype: PsType, + cls: SfgClass, + default_init: tuple[ExprLike, ...] | None = None, + ): + SfgVar.__init__(self, name, dtype) + SfgClassMember.__init__(self, cls) + self._default_init = default_init + + @property + def default_init(self) -> tuple[ExprLike, ...] | None: + return self._default_init + + +class SfgMethod(SfgClassMember): + """Instance method of a class""" + + __match_args__ = ("name", "tree", "parameters", "return_type") + + def __init__( + self, + name: str, + cls: SfgClass, + tree: SfgCallTreeNode, + return_type: PsType = void, + inline: bool = False, + const: bool = False, + ): + super().__init__(cls) + + self._name = name + self._tree = tree + self._return_type = return_type + self._inline = inline + self._const = const + + self._parameters: tuple[SfgVar, ...] + + from .postprocessing import CallTreePostProcessing + + param_collector = CallTreePostProcessing() + self._parameters = tuple( + sorted(param_collector(self._tree).function_params, key=lambda p: p.name) + ) + + @property + def name(self) -> str: + return self._name + + @property + def parameters(self) -> tuple[SfgVar, ...]: + return self._parameters + + @property + def tree(self) -> SfgCallTreeNode: + return self._tree + + @property + def return_type(self) -> PsType: + return self._return_type + + @property + def inline(self) -> bool: + return self._inline + + @property + def const(self) -> bool: + return self._const + + +class SfgConstructor(SfgClassMember): + """Constructor of a class""" + + __match_args__ = ("owning_class", "parameters", "initializers", "body") + + def __init__( + self, + cls: SfgClass, + parameters: Sequence[SfgVar] = (), + initializers: Sequence[tuple[SfgVar | str, tuple[ExprLike, ...]]] = (), + body: str = "", + ): + super().__init__(cls) + self._parameters = tuple(parameters) + self._initializers = tuple(initializers) + self._body = body + + @property + def parameters(self) -> tuple[SfgVar, ...]: + return self._parameters + + @property + def initializers(self) -> tuple[tuple[SfgVar | str, tuple[ExprLike, ...]], ...]: + return self._initializers + + @property + def body(self) -> str: + return self._body + + +class SfgClass(SfgCodeEntity): + """A C++ class.""" + + __match_args__ = ("class_keyword", "name") + + def __init__( + self, + name: str, + namespace: SfgNamespace, + class_keyword: SfgClassKeyword = SfgClassKeyword.CLASS, + bases: Sequence[str] = (), + ): + if isinstance(bases, str): + raise ValueError("Base classes must be given as a sequence.") + + super().__init__(name, namespace) + + self._class_keyword = class_keyword + self._bases_classes = tuple(bases) + + self._constructors: list[SfgConstructor] = [] + self._methods: list[SfgMethod] = [] + self._member_vars: dict[str, SfgMemberVariable] = dict() + + @property + def src_type(self) -> PsType: + # TODO: Use CppTypeFactory instead + return PsCustomType(self._name) + + @property + def base_classes(self) -> tuple[str, ...]: + return self._bases_classes + + @property + def class_keyword(self) -> SfgClassKeyword: + return self._class_keyword + + def members( + self, visibility: SfgVisibility | None = None + ) -> Generator[SfgClassMember, None, None]: + if visibility is None: + yield from chain( + self._constructors, self._methods, self._member_vars.values() + ) + else: + yield from filter(lambda m: m.visibility == visibility, self.members()) + + def member_variables( + self, visibility: SfgVisibility | None = None + ) -> Generator[SfgMemberVariable, None, None]: + if visibility is not None: + yield from filter( + lambda m: m.visibility == visibility, self._member_vars.values() + ) + else: + yield from self._member_vars.values() + + def constructors( + self, visibility: SfgVisibility | None = None + ) -> Generator[SfgConstructor, None, None]: + if visibility is not None: + yield from filter(lambda m: m.visibility == visibility, self._constructors) + else: + yield from self._constructors + + def methods( + self, visibility: SfgVisibility | None = None + ) -> Generator[SfgMethod, None, None]: + if visibility is not None: + yield from filter(lambda m: m.visibility == visibility, self._methods) + else: + yield from self._methods + + def add_member(self, member: SfgClassMember, vis: SfgVisibility): + if isinstance(member, SfgConstructor): + self._constructors.append(member) + elif isinstance(member, SfgMemberVariable): + self._add_member_variable(member) + elif isinstance(member, SfgMethod): + self._methods.append(member) + else: + raise SfgException(f"{member} is not a valid class member.") + + def _add_member_variable(self, variable: SfgMemberVariable): + if variable.name in self._member_vars: + raise SfgException( + f"Duplicate field name {variable.name} in class {self._name}" + ) + + self._member_vars[variable.name] = variable diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index aa3cd2732f62f5b9b50131b4e1ae1b48aa23e4ce..5563783a5fa14f637e0f16fd0579f266cc9c0c27 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, Iterable +from typing import Sequence, Iterable import warnings from functools import reduce from dataclasses import dataclass @@ -13,9 +13,10 @@ from pystencils.types import deconstify, PsType from pystencils.codegen.properties import FieldBasePtr, FieldShape, FieldStride from ..exceptions import SfgException +from ..config import CodeStyle from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements -from ..ir.source_components import SfgKernelParamVar +from ..lang.expressions import SfgKernelParamVar from ..lang import ( SfgVar, IFieldExtraction, @@ -27,10 +28,6 @@ from ..lang import ( includes, ) -if TYPE_CHECKING: - from ..context import SfgContext - from .source_components import SfgClass - class FlattenSequences: """Flattens any nested sequences occuring in a kernel call tree.""" @@ -65,19 +62,9 @@ class FlattenSequences: class PostProcessingContext: - def __init__(self, enclosing_class: SfgClass | None = None) -> None: - self.enclosing_class: SfgClass | None = enclosing_class + def __init__(self) -> None: self._live_variables: dict[str, SfgVar] = dict() - def is_method(self) -> bool: - return self.enclosing_class is not None - - def get_enclosing_class(self) -> SfgClass: - if self.enclosing_class is None: - raise SfgException("Cannot get the enclosing class of a free function.") - - return self.enclosing_class - @property def live_variables(self) -> set[SfgVar]: return set(self._live_variables.values()) @@ -144,8 +131,7 @@ class PostProcessingResult: class CallTreePostProcessing: - def __init__(self, enclosing_class: SfgClass | None = None): - self._enclosing_class = enclosing_class + def __init__(self): self._flattener = FlattenSequences() def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult: @@ -174,7 +160,7 @@ class CallTreePostProcessing: def get_live_variables(self, node: SfgCallTreeNode) -> set[SfgVar]: match node: case SfgSequence(): - ppc = self._ppc() + ppc = PostProcessingContext() self.handle_sequence(node, ppc) return ppc.live_variables @@ -191,9 +177,6 @@ class CallTreePostProcessing: set(), ) - def _ppc(self) -> PostProcessingContext: - return PostProcessingContext(enclosing_class=self._enclosing_class) - class SfgDeferredNode(SfgCallTreeNode, ABC): """Nodes of this type are inserted as placeholders into the kernel call tree @@ -213,7 +196,7 @@ class SfgDeferredNode(SfgCallTreeNode, ABC): def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: pass - def get_code(self, ctx: SfgContext) -> str: + def get_code(self, cstyle: CodeStyle) -> str: raise SfgException( "Invalid access into deferred node; deferred nodes must be expanded first." ) diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py deleted file mode 100644 index ea43ac8e06cd7520c75eb266c8ff9008ca7132a0..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/ir/source_components.py +++ /dev/null @@ -1,572 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from enum import Enum, auto -from typing import TYPE_CHECKING, Sequence, Generator, TypeVar -from dataclasses import replace -from itertools import chain - -from pystencils import CreateKernelConfig, create_kernel, Field -from pystencils.codegen import Kernel, Parameter -from pystencils.types import PsType, PsCustomType - -from ..lang import SfgVar, HeaderFile, void -from ..exceptions import SfgException - -if TYPE_CHECKING: - from . import SfgCallTreeNode - from ..context import SfgContext - - -class SfgEmptyLines: - def __init__(self, lines: int): - self._lines = lines - - @property - def lines(self) -> int: - return self._lines - - -class SfgHeaderInclude: - """Represent ``#include``-directives.""" - - def __init__( - self, header_file: HeaderFile, private: bool = False - ): - self._header_file = header_file - self._private = private - - @property - def file(self) -> str: - return self._header_file.filepath - - @property - def system_header(self): - return self._header_file.system_header - - @property - def private(self): - return self._private - - def __hash__(self) -> int: - return hash((self._header_file, self._private)) - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, SfgHeaderInclude) - and self._header_file == other._header_file - and self._private == other._private - ) - - -class SfgKernelNamespace: - """A namespace grouping together a number of kernels.""" - - def __init__(self, ctx: SfgContext, name: str): - self._ctx = ctx - self._name = name - self._kernel_functions: dict[str, Kernel] = dict() - - @property - def name(self): - return self._name - - @property - def kernel_functions(self): - yield from self._kernel_functions.values() - - def get_kernel_function(self, khandle: SfgKernelHandle) -> Kernel: - if khandle.kernel_namespace is not self: - raise ValueError( - f"Kernel handle does not belong to this namespace: {khandle}" - ) - - return self._kernel_functions[khandle.kernel_name] - - def add(self, kernel: Kernel, name: str | None = None): - """Adds an existing pystencils AST to this namespace. - If a name is specified, the AST's function name is changed.""" - if name is not None: - astname = name - else: - astname = kernel.name - - if astname in self._kernel_functions: - raise ValueError( - f"Duplicate ASTs: An AST with name {astname} already exists in namespace {self._name}" - ) - - if name is not None: - kernel.name = name - - self._kernel_functions[astname] = kernel - - for header in kernel.required_headers: - self._ctx.add_include(SfgHeaderInclude(HeaderFile.parse(header), private=True)) - - return SfgKernelHandle(self._ctx, astname, self, kernel.parameters) - - def create( - self, - assignments, - name: str | None = None, - config: CreateKernelConfig | None = None, - ): - """Creates a new pystencils kernel from a list of assignments and a configuration. - This is a wrapper around `pystencils.create_kernel` - with a subsequent call to `add`. - """ - if config is None: - config = CreateKernelConfig() - - if name is not None: - if name in self._kernel_functions: - raise ValueError( - f"Duplicate ASTs: An AST with name {name} already exists in namespace {self._name}" - ) - config = replace(config, function_name=name) - - # type: ignore - ast = create_kernel(assignments, config=config) - return self.add(ast) - - -class SfgKernelHandle: - """A handle that represents a pystencils kernel within a kernel namespace.""" - - def __init__( - self, - ctx: SfgContext, - name: str, - namespace: SfgKernelNamespace, - parameters: Sequence[Parameter], - ): - self._ctx = ctx - self._name = name - self._namespace = namespace - self._parameters = [SfgKernelParamVar(p) for p in parameters] - - self._scalar_params: set[SfgVar] = set() - self._fields: set[Field] = set() - - for param in self._parameters: - if param.wrapped.is_field_parameter: - self._fields |= set(param.wrapped.fields) - else: - self._scalar_params.add(param) - - @property - def kernel_name(self): - return self._name - - @property - def kernel_namespace(self): - return self._namespace - - @property - def fully_qualified_name(self): - match self._ctx.fully_qualified_namespace: - case None: - return f"{self.kernel_namespace.name}::{self.kernel_name}" - case fqn: - return f"{fqn}::{self.kernel_namespace.name}::{self.kernel_name}" - - @property - def parameters(self) -> Sequence[SfgKernelParamVar]: - return self._parameters - - @property - def scalar_parameters(self) -> set[SfgVar]: - return self._scalar_params - - @property - def fields(self): - return self._fields - - def get_kernel_function(self) -> Kernel: - return self._namespace.get_kernel_function(self) - - -SymbolLike_T = TypeVar("SymbolLike_T", bound=Parameter) - - -class SfgKernelParamVar(SfgVar): - __match_args__ = ("wrapped",) - - """Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`.""" - - def __init__(self, param: Parameter): - self._param = param - super().__init__(param.name, param.dtype) - - @property - def wrapped(self) -> Parameter: - return self._param - - def _args(self): - return (self._param,) - - -class SfgFunction: - __match_args__ = ("name", "tree", "parameters") - - def __init__( - self, - name: str, - tree: SfgCallTreeNode, - return_type: PsType = void, - _is_method: bool = False, - ): - self._name = name - self._tree = tree - self._return_type = return_type - - self._parameters: set[SfgVar] - if not _is_method: - from .postprocessing import CallTreePostProcessing - - param_collector = CallTreePostProcessing() - self._parameters = param_collector(self._tree).function_params - - @property - def name(self) -> str: - return self._name - - @property - def parameters(self) -> set[SfgVar]: - return self._parameters - - @property - def tree(self) -> SfgCallTreeNode: - return self._tree - - @property - def return_type(self) -> PsType: - return self._return_type - - -class SfgVisibility(Enum): - DEFAULT = auto() - PRIVATE = auto() - PROTECTED = auto() - PUBLIC = auto() - - def __str__(self) -> str: - match self: - case SfgVisibility.DEFAULT: - return "" - case SfgVisibility.PRIVATE: - return "private" - case SfgVisibility.PROTECTED: - return "protected" - case SfgVisibility.PUBLIC: - return "public" - - -class SfgClassKeyword(Enum): - STRUCT = auto() - CLASS = auto() - - def __str__(self) -> str: - match self: - case SfgClassKeyword.STRUCT: - return "struct" - case SfgClassKeyword.CLASS: - return "class" - - -class SfgClassMember(ABC): - def __init__(self) -> None: - self._cls: SfgClass | None = None - self._visibility: SfgVisibility | None = None - - @property - def owning_class(self) -> SfgClass: - if self._cls is None: - raise SfgException(f"{self} is not bound to a class.") - return self._cls - - @property - def visibility(self) -> SfgVisibility: - if self._visibility is None: - raise SfgException( - f"{self} is not bound to a class and therefore has no visibility." - ) - return self._visibility - - @property - def is_bound(self) -> bool: - return self._cls is not None - - def _bind(self, cls: SfgClass, vis: SfgVisibility): - if self.is_bound: - raise SfgException( - f"Binding {self} to class {cls.class_name} failed: " - f"{self} was already bound to {self.owning_class.class_name}" - ) - self._cls = cls - self._vis = vis - - -class SfgVisibilityBlock: - def __init__(self, visibility: SfgVisibility) -> None: - self._vis = visibility - self._members: list[SfgClassMember] = [] - self._cls: SfgClass | None = None - - @property - def visibility(self) -> SfgVisibility: - return self._vis - - def append_member(self, member: SfgClassMember): - if self._cls is not None: - self._cls._add_member(member, self._vis) - self._members.append(member) - - def members(self) -> Generator[SfgClassMember, None, None]: - yield from self._members - - @property - def is_bound(self) -> bool: - return self._cls is not None - - def _bind(self, cls: SfgClass): - if self._cls is not None: - raise SfgException( - f"Binding visibility block to class {cls.class_name} failed: " - f"was already bound to {self._cls.class_name}" - ) - self._cls = cls - - -class SfgInClassDefinition(SfgClassMember): - def __init__(self, text: str): - SfgClassMember.__init__(self) - self._text = text - - @property - def text(self) -> str: - return self._text - - def __str__(self) -> str: - return self._text - - -class SfgMemberVariable(SfgVar, SfgClassMember): - def __init__(self, name: str, dtype: PsType): - SfgVar.__init__(self, name, dtype) - SfgClassMember.__init__(self) - - -class SfgMethod(SfgFunction, SfgClassMember): - def __init__( - self, - name: str, - tree: SfgCallTreeNode, - return_type: PsType = PsCustomType("void"), - inline: bool = False, - const: bool = False, - ): - SfgFunction.__init__(self, name, tree, return_type=return_type, _is_method=True) - SfgClassMember.__init__(self) - - self._inline = inline - self._const = const - self._parameters: set[SfgVar] = set() - - @property - def inline(self) -> bool: - return self._inline - - @property - def const(self) -> bool: - return self._const - - def _bind(self, cls: SfgClass, vis: SfgVisibility): - super()._bind(cls, vis) - - from .postprocessing import CallTreePostProcessing - - param_collector = CallTreePostProcessing(enclosing_class=cls) - self._parameters = param_collector(self._tree).function_params - - -class SfgConstructor(SfgClassMember): - __match_args__ = ("parameters", "initializers", "body") - - def __init__( - self, - parameters: Sequence[SfgVar] = (), - initializers: Sequence[str] = (), - body: str = "", - ): - SfgClassMember.__init__(self) - self._parameters = tuple(parameters) - self._initializers = tuple(initializers) - self._body = body - - @property - def parameters(self) -> tuple[SfgVar, ...]: - return self._parameters - - @property - def initializers(self) -> tuple[str, ...]: - return self._initializers - - @property - def body(self) -> str: - return self._body - - -class SfgClass: - """Models a C++ class. - - ### Adding members to classes - - Members are never added directly to a class. Instead, they are added to - an [SfgVisibilityBlock][pystencilssfg.source_components.SfgVisibilityBlock] - which defines their syntactic position and visibility modifier in the code. - At the top of every class, there is a default visibility block - accessible through the `default` property. - To add members with custom visibility, create a new SfgVisibilityBlock, - add members to the block, and add the block using `append_visibility_block`. - - A more succinct interface for constructing classes is available through the - [SfgClassComposer][pystencilssfg.composer.SfgClassComposer]. - """ - - __match_args__ = ("class_name",) - - def __init__( - self, - class_name: str, - class_keyword: SfgClassKeyword = SfgClassKeyword.CLASS, - bases: Sequence[str] = (), - ): - if isinstance(bases, str): - raise ValueError("Base classes must be given as a sequence.") - - self._class_name = class_name - self._class_keyword = class_keyword - self._bases_classes = tuple(bases) - - self._default_block = SfgVisibilityBlock(SfgVisibility.DEFAULT) - self._default_block._bind(self) - self._blocks = [self._default_block] - - self._definitions: list[SfgInClassDefinition] = [] - self._constructors: list[SfgConstructor] = [] - self._methods: list[SfgMethod] = [] - self._member_vars: dict[str, SfgMemberVariable] = dict() - - @property - def class_name(self) -> str: - return self._class_name - - @property - def src_type(self) -> PsType: - return PsCustomType(self._class_name) - - @property - def base_classes(self) -> tuple[str, ...]: - return self._bases_classes - - @property - def class_keyword(self) -> SfgClassKeyword: - return self._class_keyword - - @property - def default(self) -> SfgVisibilityBlock: - return self._default_block - - def append_visibility_block(self, block: SfgVisibilityBlock): - if block.visibility == SfgVisibility.DEFAULT: - raise SfgException( - "Can't add another block with DEFAULT visibility to a class. Use `.default` instead." - ) - - block._bind(self) - for m in block.members(): - self._add_member(m, block.visibility) - self._blocks.append(block) - - def visibility_blocks(self) -> Generator[SfgVisibilityBlock, None, None]: - yield from self._blocks - - def members( - self, visibility: SfgVisibility | None = None - ) -> Generator[SfgClassMember, None, None]: - if visibility is None: - yield from chain.from_iterable(b.members() for b in self._blocks) - else: - yield from chain.from_iterable( - b.members() - for b in filter(lambda b: b.visibility == visibility, self._blocks) - ) - - def definitions( - self, visibility: SfgVisibility | None = None - ) -> Generator[SfgInClassDefinition, None, None]: - if visibility is not None: - yield from filter(lambda m: m.visibility == visibility, self._definitions) - else: - yield from self._definitions - - def member_variables( - self, visibility: SfgVisibility | None = None - ) -> Generator[SfgMemberVariable, None, None]: - if visibility is not None: - yield from filter( - lambda m: m.visibility == visibility, self._member_vars.values() - ) - else: - yield from self._member_vars.values() - - def constructors( - self, visibility: SfgVisibility | None = None - ) -> Generator[SfgConstructor, None, None]: - if visibility is not None: - yield from filter(lambda m: m.visibility == visibility, self._constructors) - else: - yield from self._constructors - - def methods( - self, visibility: SfgVisibility | None = None - ) -> Generator[SfgMethod, None, None]: - if visibility is not None: - yield from filter(lambda m: m.visibility == visibility, self._methods) - else: - yield from self._methods - - # PRIVATE - - def _add_member(self, member: SfgClassMember, vis: SfgVisibility): - if isinstance(member, SfgConstructor): - self._add_constructor(member) - elif isinstance(member, SfgMemberVariable): - self._add_member_variable(member) - elif isinstance(member, SfgMethod): - self._add_method(member) - elif isinstance(member, SfgInClassDefinition): - self._add_definition(member) - else: - raise SfgException(f"{member} is not a valid class member.") - - member._bind(self, vis) - - def _add_definition(self, definition: SfgInClassDefinition): - self._definitions.append(definition) - - def _add_constructor(self, constr: SfgConstructor): - self._constructors.append(constr) - - def _add_method(self, method: SfgMethod): - self._methods.append(method) - - def _add_member_variable(self, variable: SfgMemberVariable): - if variable.name in self._member_vars: - raise SfgException( - f"Duplicate field name {variable.name} in class {self._class_name}" - ) - - self._member_vars[variable.name] = variable diff --git a/src/pystencilssfg/ir/syntax.py b/src/pystencilssfg/ir/syntax.py new file mode 100644 index 0000000000000000000000000000000000000000..cdbd4c283b6bb0078e1051f89565b3b6b32d8d21 --- /dev/null +++ b/src/pystencilssfg/ir/syntax.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +from enum import Enum, auto +from typing import ( + Iterable, + TypeVar, + Generic, +) + +from ..lang import HeaderFile + +from .entities import ( + SfgNamespace, + SfgKernelHandle, + SfgFunction, + SfgClassMember, + SfgVisibility, + SfgClass, +) + +# ========================================================================================================= +# +# SYNTACTICAL ELEMENTS +# +# These classes model *code elements*, which represent the actual syntax objects that populate the output +# files, their namespaces and class bodies. +# +# ========================================================================================================= + + +SourceEntity_T = TypeVar( + "SourceEntity_T", + bound=SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass, + covariant=True, +) +"""Source entities that may have declarations and definitions.""" + + +class SfgEntityDecl(Generic[SourceEntity_T]): + """Declaration of a function, class, method, or constructor""" + + __match_args__ = ("entity",) + + def __init__(self, entity: SourceEntity_T) -> None: + self._entity = entity + + @property + def entity(self) -> SourceEntity_T: + return self._entity + + +class SfgEntityDef(Generic[SourceEntity_T]): + """Definition of a function, class, method, or constructor""" + + __match_args__ = ("entity",) + + def __init__(self, entity: SourceEntity_T) -> None: + self._entity = entity + + @property + def entity(self) -> SourceEntity_T: + return self._entity + + +SfgClassBodyElement = str | SfgEntityDecl[SfgClassMember] | SfgEntityDef[SfgClassMember] +"""Elements that may be placed in the visibility blocks of a class body.""" + + +class SfgVisibilityBlock: + """Visibility-qualified block inside a class definition body. + + Visibility blocks host the code elements placed inside a class body: + method and constructor declarations, + in-class method and constructor definitions, + as well as variable declarations and definitions. + + Args: + visibility: The visibility qualifier of this block + """ + + __match_args__ = ("visibility", "elements") + + def __init__(self, visibility: SfgVisibility) -> None: + self._vis = visibility + self._elements: list[SfgClassBodyElement] = [] + self._cls: SfgClass | None = None + + @property + def visibility(self) -> SfgVisibility: + return self._vis + + @property + def elements(self) -> list[SfgClassBodyElement]: + return self._elements + + @elements.setter + def elements(self, elems: Iterable[SfgClassBodyElement]): + self._elements = list(elems) + + +class SfgNamespaceBlock: + """A C++ namespace block. + + Args: + namespace: Namespace associated with this block + label: Label printed at the opening brace of this block. + This may be the namespace name, or a compressed qualified + name containing one or more of its parent namespaces. + """ + + __match_args__ = ( + "namespace", + "elements", + "label", + ) + + def __init__(self, namespace: SfgNamespace, label: str | None = None) -> None: + self._namespace = namespace + self._label = label if label is not None else namespace.name + self._elements: list[SfgNamespaceElement] = [] + + @property + def namespace(self) -> SfgNamespace: + return self._namespace + + @property + def label(self) -> str: + return self._label + + @property + def elements(self) -> list[SfgNamespaceElement]: + """Sequence of source elements that make up the body of this namespace""" + return self._elements + + @elements.setter + def elements(self, elems: Iterable[SfgNamespaceElement]): + self._elements = list(elems) + + +class SfgClassBody: + """Body of a class definition.""" + + __match_args__ = ("associated_class", "visibility_blocks") + + def __init__( + self, + cls: SfgClass, + default_block: SfgVisibilityBlock, + vis_blocks: Iterable[SfgVisibilityBlock], + ) -> None: + self._cls = cls + assert default_block.visibility == SfgVisibility.DEFAULT + self._default_block = default_block + self._blocks = [self._default_block] + list(vis_blocks) + + @property + def associated_class(self) -> SfgClass: + return self._cls + + @property + def default(self) -> SfgVisibilityBlock: + return self._default_block + + def append_visibility_block(self, block: SfgVisibilityBlock): + if block.visibility == SfgVisibility.DEFAULT: + raise ValueError( + "Can't add another block with DEFAULT visibility to this class body." + ) + self._blocks.append(block) + + @property + def visibility_blocks(self) -> tuple[SfgVisibilityBlock, ...]: + return tuple(self._blocks) + + +SfgNamespaceElement = ( + str | SfgNamespaceBlock | SfgClassBody | SfgEntityDecl | SfgEntityDef +) +"""Elements that may be placed inside a namespace, including the global namespace.""" + + +class SfgSourceFileType(Enum): + HEADER = auto() + TRANSLATION_UNIT = auto() + + +class SfgSourceFile: + """A C++ source file. + + Args: + name: Name of the file (without parent directories), e.g. ``Algorithms.cpp`` + file_type: Type of the source file (header or translation unit) + prelude: Optionally, text of the prelude comment printed at the top of the file + """ + + def __init__( + self, name: str, file_type: SfgSourceFileType, prelude: str | None = None + ) -> None: + self._name: str = name + self._file_type: SfgSourceFileType = file_type + self._prelude: str | None = prelude + self._includes: list[HeaderFile] = [] + self._elements: list[SfgNamespaceElement] = [] + + @property + def name(self) -> str: + """Name of this source file""" + return self._name + + @property + def file_type(self) -> SfgSourceFileType: + """File type of this source file""" + return self._file_type + + @property + def prelude(self) -> str | None: + """Text of the prelude comment""" + return self._prelude + + @prelude.setter + def prelude(self, text: str | None): + self._prelude = text + + @property + def includes(self) -> list[HeaderFile]: + """Sequence of header files to be included at the top of this file""" + return self._includes + + @includes.setter + def includes(self, incl: Iterable[HeaderFile]): + self._includes = list(incl) + + @property + def elements(self) -> list[SfgNamespaceElement]: + """Sequence of source elements comprising the body of this file""" + return self._elements + + @elements.setter + def elements(self, elems: Iterable[SfgNamespaceElement]): + self._elements = list(elems) diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index 9218ec2b7d7f94517e35a2c9a8e4e4ddaa7c3a2a..a8de86be10ce44c2ac2d49cc3b5fba0e1549de50 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -2,6 +2,7 @@ from .headers import HeaderFile from .expressions import ( SfgVar, + SfgKernelParamVar, AugExpr, VarLike, _VarLike, @@ -21,6 +22,7 @@ from .types import cpptype, void, Ref, strip_ptr_ref __all__ = [ "HeaderFile", "SfgVar", + "SfgKernelParamVar", "AugExpr", "VarLike", "_VarLike", diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index f86140ee7a3775caab19f69c34ef97822975b95e..72287eaad73afc08770d451c0847c362ef561519 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -6,7 +6,8 @@ from abc import ABC, abstractmethod import sympy as sp from pystencils import TypedSymbol -from pystencils.types import PsType, UserTypeSpec, create_type +from pystencils.codegen import Parameter +from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type from ..exceptions import SfgException from .headers import HeaderFile @@ -74,6 +75,23 @@ class SfgVar: return self.name_and_type() +class SfgKernelParamVar(SfgVar): + __match_args__ = ("wrapped",) + + """Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`.""" + + def __init__(self, param: Parameter): + self._param = param + super().__init__(param.name, param.dtype) + + @property + def wrapped(self) -> Parameter: + return self._param + + def _args(self): + return (self._param,) + + class DependentExpression: """Wrapper around a C++ expression code string, annotated with a set of variables and a set of header files this expression depends on. @@ -434,7 +452,7 @@ def depends(expr: ExprLike) -> set[SfgVar]: raise ValueError(f"Invalid expression: {expr}") -def includes(expr: ExprLike) -> set[HeaderFile]: +def includes(obj: ExprLike | PsType) -> set[HeaderFile]: """Determine the set of header files an expression depends on. Args: @@ -447,21 +465,33 @@ def includes(expr: ExprLike) -> set[HeaderFile]: ValueError: If the argument was not a valid variable or expression """ - match expr: + if isinstance(obj, PsType): + obj = strip_ptr_ref(obj) + + match obj: + case CppType(): + return set(obj.includes) + + case PsType(): + headers = set(HeaderFile.parse(h) for h in obj.required_headers) + if isinstance(obj, PsIntegerType): + headers.add(HeaderFile.parse("<cstdint>")) + return headers + case SfgVar(_, dtype): - match dtype: - case CppType(): - return set(dtype.includes) - case _: - return set(HeaderFile.parse(h) for h in dtype.required_headers) + return includes(dtype) + case TypedSymbol(): - return includes(asvar(expr)) + return includes(asvar(obj)) + case str(): return set() + case AugExpr(): - return set(expr.includes) + return set(obj.includes) + case _: - raise ValueError(f"Invalid expression: {expr}") + raise ValueError(f"Invalid expression: {obj}") class IFieldExtraction(ABC): diff --git a/src/pystencilssfg/visitors/__init__.py b/src/pystencilssfg/visitors/__init__.py deleted file mode 100644 index fc7af1b6363bdb1167ee6c0e164ba87e31fbb6b0..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/visitors/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .dispatcher import visitor - -__all__ = [ - "visitor", -] diff --git a/src/pystencilssfg/visitors/dispatcher.py b/src/pystencilssfg/visitors/dispatcher.py deleted file mode 100644 index 85a0f087bbf5c1ffcfb99d3ec3acbdbda77089ab..0000000000000000000000000000000000000000 --- a/src/pystencilssfg/visitors/dispatcher.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations -from typing import Callable, TypeVar, Generic -from types import MethodType - -from functools import wraps - -V = TypeVar("V") -R = TypeVar("R") - - -class VisitorDispatcher(Generic[V, R]): - def __init__(self, wrapped_method: Callable[..., R]): - self._dispatch_dict: dict[type, Callable[..., R]] = {} - self._wrapped_method: Callable[..., R] = wrapped_method - - def case(self, node_type: type): - """Decorator for visitor's case handlers.""" - - def decorate(handler: Callable[..., R]): - if node_type in self._dispatch_dict: - raise ValueError(f"Duplicate visitor case {node_type}") - self._dispatch_dict[node_type] = handler - return handler - - return decorate - - def __call__(self, instance: V, node: object, *args, **kwargs) -> R: - for cls in node.__class__.mro(): - if cls in self._dispatch_dict: - return self._dispatch_dict[cls](instance, node, *args, **kwargs) - - return self._wrapped_method(instance, node, *args, **kwargs) - - def __get__(self, obj: V, objtype=None) -> Callable[..., R]: - if obj is None: - return self - return MethodType(self, obj) - - -def visitor(method): - """Decorator to create a visitor using type-based dispatch. - - Use this decorator to convert a method into a visitor, like shown below. - After declaring a method (e.g. `my_method`) a visitor, - its case handlers can be declared using the `my_method.case` decorator, like this: - - ```Python - class DemoVisitor: - @visitor - def visit(self, obj: object): - # fallback case - ... - - @visit.case(str) - def visit_str(self, obj: str): - # code for handling a str - ``` - - When `visit` is later called with some object `x`, the case handler to be executed is - determined according to the method resolution order of `x` (i.e. along its type's inheritance hierarchy). - If no case matches, the fallback code in the original visitor method is executed. - In this example, if `visit` is called with an object of type `str`, the call is dispatched to `visit_str`. - - This visitor dispatch method is primarily designed for traversing abstract syntax tree structures. - The primary visitor method (`visit` in above example) should define the common parent type of all object - types the visitor can handle, with cases declared for all required subtypes. - However, this type relationship is not enforced at runtime. - """ - return wraps(method)(VisitorDispatcher(method)) diff --git a/tests/extensions/test_sycl.py b/tests/extensions/test_sycl.py index db99278c3a00a333400a7a18882163676c389d00..71effb60a3f10a3d5f505b01b1ef256cea9ad45e 100644 --- a/tests/extensions/test_sycl.py +++ b/tests/extensions/test_sycl.py @@ -1,12 +1,10 @@ import pytest -from pystencilssfg import SourceFileGenerator import pystencilssfg.extensions.sycl as sycl import pystencils as ps -from pystencilssfg import SfgContext -def test_parallel_for_1_kernels(): - sfg = sycl.SyclComposer(SfgContext()) +def test_parallel_for_1_kernels(sfg): + sfg = sycl.SyclComposer(sfg) data_type = "double" dim = 2 f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]") @@ -24,8 +22,8 @@ def test_parallel_for_1_kernels(): ) -def test_parallel_for_2_kernels(): - sfg = sycl.SyclComposer(SfgContext()) +def test_parallel_for_2_kernels(sfg): + sfg = sycl.SyclComposer(sfg) data_type = "double" dim = 2 f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]") @@ -43,8 +41,8 @@ def test_parallel_for_2_kernels(): ) -def test_parallel_for_2_kernels_fail(): - sfg = sycl.SyclComposer(SfgContext()) +def test_parallel_for_2_kernels_fail(sfg): + sfg = sycl.SyclComposer(sfg) data_type = "double" dim = 2 f, g = ps.fields(f"f,g:{data_type}[{dim}D]") diff --git a/tests/generator/test_config.py b/tests/generator/test_config.py index 4485dc22e639b185b8b0756ae6d69f92af5e45e8..250c158c633d6f8fc2f11f0d4b3b2cbd13a13128 100644 --- a/tests/generator/test_config.py +++ b/tests/generator/test_config.py @@ -1,4 +1,5 @@ import pytest +from pathlib import Path from pystencilssfg.config import ( SfgConfig, @@ -86,7 +87,7 @@ def test_from_commandline(sample_config_module): cli_args = CommandLineParameters(args) cfg = cli_args.get_config() - assert cfg.output_directory == ".out" + assert cfg.output_directory == Path(".out") assert cfg.extensions.header == "h++" assert cfg.extensions.impl == "c++" @@ -100,7 +101,7 @@ def test_from_commandline(sample_config_module): assert cfg.clang_format.code_style == "llvm" assert cfg.clang_format.skip is True assert ( - cfg.output_directory == "gen_sources" + cfg.output_directory == Path("gen_sources") ) # value from config module overridden by commandline assert cfg.outer_namespace == "myproject" assert cfg.extensions.header == "hpp" diff --git a/tests/generator_scripts/index.yaml b/tests/generator_scripts/index.yaml index c78b335af26d9a175893499345ec29a20d30dd8c..0e08e228702bf5c762120e9bc66117a5892bf66d 100644 --- a/tests/generator_scripts/index.yaml +++ b/tests/generator_scripts/index.yaml @@ -17,6 +17,17 @@ TestIllegalArgs: extra-args: [--sfg-file-extensionss, ".c++,.h++"] expect-failure: true +TestIncludeSorting: + sfg-args: + output-mode: header-only + expect-code: + hpp: + - regex: >- + #include\s\<memory>\s* + #include\s<vector>\s* + #include\s<array> + strip-whitespace: true + # Basic Composer Functionality BasicDefinitions: @@ -25,7 +36,7 @@ BasicDefinitions: expect-code: hpp: - regex: >- - #include\s\"config\.h\"\s* + #include\s\"config\.h\"(\s|.)* namespace\s+awesome\s+{\s+.+\s+ #define\sPI\s3\.1415\s+ using\snamespace\sstd\;\s+ @@ -48,6 +59,10 @@ Conditionals: - regex: if\s*\(\s*noodle\s==\sNoodles::RIGATONI\s\|\|\snoodle\s==\sNoodles::SPAGHETTI\s*\) count: 1 +NestedNamespaces: + sfg-args: + output-mode: header-only + # Kernel Generation ScaleKernel: diff --git a/tests/generator_scripts/source/BasicDefinitions.py b/tests/generator_scripts/source/BasicDefinitions.py index 7cfe352910b676429b97cb8f3d29bec68b74810a..4453066583348e6fad77ae89cd071fccfc5ce19d 100644 --- a/tests/generator_scripts/source/BasicDefinitions.py +++ b/tests/generator_scripts/source/BasicDefinitions.py @@ -5,11 +5,11 @@ cfg = SfgConfig() cfg.clang_format.skip = True with SourceFileGenerator(cfg) as sfg: + sfg.namespace("awesome") + sfg.prelude("Expect the unexpected, and you shall never be surprised.") sfg.include("<iostream>") sfg.include("config.h") - sfg.namespace("awesome") - sfg.code("#define PI 3.1415") sfg.code("using namespace std;") diff --git a/tests/generator_scripts/source/JacobiMdspan.py b/tests/generator_scripts/source/JacobiMdspan.py index bbe95ac272edbf3b4d9711088c91168cdb525d54..2e0741a046d317ed41d091524c0ab6b855318f46 100644 --- a/tests/generator_scripts/source/JacobiMdspan.py +++ b/tests/generator_scripts/source/JacobiMdspan.py @@ -15,7 +15,9 @@ with SourceFileGenerator() as sfg: @kernel def poisson_jacobi(): - u_dst[0,0] @= (h**2 * f[0, 0] + u_src[1, 0] + u_src[-1, 0] + u_src[0, 1] + u_src[0, -1]) / 4 + u_dst[0, 0] @= ( + h**2 * f[0, 0] + u_src[1, 0] + u_src[-1, 0] + u_src[0, 1] + u_src[0, -1] + ) / 4 poisson_kernel = sfg.kernels.create(poisson_jacobi) @@ -23,5 +25,5 @@ with SourceFileGenerator() as sfg: sfg.map_field(u_src, mdspan.from_field(u_src, layout_policy="layout_left")), sfg.map_field(u_dst, mdspan.from_field(u_dst, layout_policy="layout_left")), sfg.map_field(f, mdspan.from_field(f, layout_policy="layout_left")), - sfg.call(poisson_kernel) + sfg.call(poisson_kernel), ) diff --git a/tests/generator_scripts/source/NestedNamespaces.harness.cpp b/tests/generator_scripts/source/NestedNamespaces.harness.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ea7c465716a0d882b44cd6fe0a8d17c35dd94c02 --- /dev/null +++ b/tests/generator_scripts/source/NestedNamespaces.harness.cpp @@ -0,0 +1,12 @@ +#include "NestedNamespaces.hpp" + +static_assert( outer::X == 13 ); +static_assert( outer::inner::Y == 52 ); +static_assert( outer::Z == 41 ); +static_assert( outer::second_inner::W == 91 ); +static_assert( outer::inner::innermost::V == 29 ); +static_assert( GLOBAL == 42 ); + +int main() { + return 0; +} diff --git a/tests/generator_scripts/source/NestedNamespaces.py b/tests/generator_scripts/source/NestedNamespaces.py new file mode 100644 index 0000000000000000000000000000000000000000..4af7bc7b55f729157ddadbf361bc2b77db60c975 --- /dev/null +++ b/tests/generator_scripts/source/NestedNamespaces.py @@ -0,0 +1,19 @@ +from pystencilssfg import SourceFileGenerator + +with SourceFileGenerator() as sfg: + + with sfg.namespace("outer"): + sfg.code("constexpr int X = 13;") + + with sfg.namespace("inner"): + sfg.code("constexpr int Y = 52;") + + sfg.code("constexpr int Z = 41;") + + with sfg.namespace("outer::second_inner"): + sfg.code("constexpr int W = 91;") + + with sfg.namespace("outer::inner::innermost"): + sfg.code("constexpr int V = 29;") + + sfg.code("constexpr int GLOBAL = 42;") diff --git a/tests/generator_scripts/source/ScaleKernel.py b/tests/generator_scripts/source/ScaleKernel.py index 8bcc75fb7c98e8d46602c7f3f888650d8b8e011c..2242a3bc34f8edf0fd6ecff6a8c1bd14acf4b0fd 100644 --- a/tests/generator_scripts/source/ScaleKernel.py +++ b/tests/generator_scripts/source/ScaleKernel.py @@ -3,6 +3,8 @@ from pystencils import TypedSymbol, fields, kernel from pystencilssfg import SourceFileGenerator with SourceFileGenerator() as sfg: + sfg.namespace("gen") + N = 10 α = TypedSymbol("alpha", "float32") src, dst = fields(f"src, dst: float32[{N}]") @@ -13,7 +15,6 @@ with SourceFileGenerator() as sfg: khandle = sfg.kernels.create(scale) - sfg.namespace("gen") sfg.code(f"constexpr int N = {N};") sfg.klass("Scale")( diff --git a/tests/generator_scripts/source/StlContainers1D.py b/tests/generator_scripts/source/StlContainers1D.py index 3f6ec2c953a6537bef9785d837a2def88439972a..91b29110b2aeff7c713c2ea8e89482ecca9bb388 100644 --- a/tests/generator_scripts/source/StlContainers1D.py +++ b/tests/generator_scripts/source/StlContainers1D.py @@ -10,21 +10,18 @@ with SourceFileGenerator() as sfg: src, dst = ps.fields("src, dst: double[1D]") - asms = [ - ps.Assignment(dst[0], sp.Rational(1, 3) * (src[-1] + src[0] + src[1])) - ] + asms = [ps.Assignment(dst[0], sp.Rational(1, 3) * (src[-1] + src[0] + src[1]))] kernel = sfg.kernels.create(asms, "average") sfg.function("averageVector")( sfg.map_field(src, std.vector.from_field(src)), sfg.map_field(dst, std.vector.from_field(dst)), - sfg.call(kernel) + sfg.call(kernel), ) sfg.function("averageSpan")( sfg.map_field(src, std.span.from_field(src)), sfg.map_field(dst, std.span.from_field(dst)), - sfg.call(kernel) + sfg.call(kernel), ) - diff --git a/tests/generator_scripts/source/TestIncludeSorting.py b/tests/generator_scripts/source/TestIncludeSorting.py new file mode 100644 index 0000000000000000000000000000000000000000..8a584f6b0f4a836b50d862b2d3a161ce98caa517 --- /dev/null +++ b/tests/generator_scripts/source/TestIncludeSorting.py @@ -0,0 +1,23 @@ +from pystencilssfg import SourceFileGenerator, SfgConfig +from pystencilssfg.lang import HeaderFile + + +def sortkey(h: HeaderFile): + try: + return [ + "memory", + "vector", + "array" + ].index(h.filepath) + except ValueError: + return 100 + + +cfg = SfgConfig() +cfg.codestyle.includes_sorting_key = sortkey + + +with SourceFileGenerator(cfg) as sfg: + sfg.include("<array>") + sfg.include("<memory>") + sfg.include("<vector>") diff --git a/tests/integration/cmake_project/GenTest.py b/tests/integration/cmake_project/GenTest.py index 8399e7061ae79b5b3a8c0fe6c3d544f0b5d6f586..81aec18e250e33cb1969565d80d09ee31206e69e 100644 --- a/tests/integration/cmake_project/GenTest.py +++ b/tests/integration/cmake_project/GenTest.py @@ -2,7 +2,6 @@ from pystencilssfg import SourceFileGenerator with SourceFileGenerator() as sfg: sfg.namespace("gen") - retval = 42 if sfg.context.project_info is None else sfg.context.project_info sfg.function("getValue", return_type="int")( diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 070743ae6ce7e63ce4c825142ed28fb6052d647f..9d51c8fa6ef944a51d9c60219c1460061d6514b1 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -2,7 +2,6 @@ import sympy as sp from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type from pystencils.types import PsCustomType -from pystencilssfg import SfgContext, SfgComposer from pystencilssfg.composer import make_sequence from pystencilssfg.lang import IFieldExtraction, AugExpr @@ -11,10 +10,7 @@ from pystencilssfg.ir import SfgStatements, SfgSequence from pystencilssfg.ir.postprocessing import CallTreePostProcessing -def test_live_vars(): - ctx = SfgContext() - sfg = SfgComposer(ctx) - +def test_live_vars(sfg): f, g = fields("f, g(2): double[2D]") x, y = [TypedSymbol(n, "double") for n in "xy"] z = sp.Symbol("z") @@ -42,10 +38,7 @@ def test_live_vars(): assert free_vars == expected -def test_find_sympy_symbols(): - ctx = SfgContext() - sfg = SfgComposer(ctx) - +def test_find_sympy_symbols(sfg): f, g = fields("f, g(2): double[2D]") x, y, z = sp.symbols("x, y, z") @@ -94,7 +87,7 @@ class DemoFieldExtraction(IFieldExtraction): return AugExpr.format("{}.stride({})", self.obj, coordinate) -def test_field_extraction(): +def test_field_extraction(sfg): sx, sy, tx, ty = [ TypedSymbol(n, create_type("int64")) for n in ("sx", "sy", "tx", "ty") ] @@ -104,8 +97,6 @@ def test_field_extraction(): def set_constant(): f.center @= 13.2 - sfg = SfgComposer(SfgContext()) - khandle = sfg.kernels.create(set_constant) extraction = DemoFieldExtraction("f") @@ -129,7 +120,7 @@ def test_field_extraction(): assert stmt.code_string == line -def test_duplicate_field_shapes(): +def test_duplicate_field_shapes(sfg): N, tx, ty = [TypedSymbol(n, create_type("int64")) for n in ("N", "tx", "ty")] f = Field("f", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty)) g = Field("g", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty)) @@ -138,8 +129,6 @@ def test_duplicate_field_shapes(): def set_constant(): f.center @= g.center(0) - sfg = SfgComposer(SfgContext()) - khandle = sfg.kernels.create(set_constant) call_tree = make_sequence(