diff --git a/src/pystencilssfg/__init__.py b/src/pystencilssfg/__init__.py index b2def3b84ba3eab0706b095878a9c165322c1d8b..fea6f8a10e198f86f18be34a6cad381ed3238e19 100644 --- a/src/pystencilssfg/__init__.py +++ b/src/pystencilssfg/__init__.py @@ -1,5 +1,5 @@ -from .config import SfgConfig -from .generator import SourceFileGenerator, GLOBAL_NAMESPACE, OutputMode +from .config import SfgConfig, GLOBAL_NAMESPACE, OutputMode +from .generator import SourceFileGenerator from .composer import SfgComposer from .context import SfgContext from .lang import SfgVar, AugExpr diff --git a/src/pystencilssfg/config.py b/src/pystencilssfg/config.py index 18aaa515143547bc03b438c5dda4e0169fc16d93..a94d9ad084cc71e407f57124f86e8863beeb02e0 100644 --- a/src/pystencilssfg/config.py +++ b/src/pystencilssfg/config.py @@ -3,7 +3,7 @@ from __future__ import annotations from argparse import ArgumentParser from types import ModuleType -from typing import Any, Sequence +from typing import Any, Sequence, Callable from dataclasses import dataclass from enum import Enum, auto from os import path @@ -12,6 +12,8 @@ from pathlib import Path from pystencils.codegen.config import ConfigBase, Option, BasicOption, Category +from .lang import HeaderFile + class SfgConfigException(Exception): ... # noqa: E701 @@ -61,6 +63,12 @@ class CodeStyle(ConfigBase): indent_width: BasicOption[int] = BasicOption(2) """The number of spaces successively nested blocks should be indented with""" + includes_sorting_key: BasicOption[Callable[[HeaderFile], Any]] = BasicOption() + """Key function that will be used to sort `#include` statements in generated files. + + Pystencils-sfg will instruct clang-tidy to forego include sorting if this option is set. + """ + # TODO possible future options: # - newline before opening { # - trailing return types diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index 032c1e4ff275abb45be03de24391e38a0a9d4429..24c38e1f26e50f4569da919ebfb272ae8251baac 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -20,7 +20,7 @@ class SfgContext: self, header_file: SfgSourceFile, impl_file: SfgSourceFile | None, - outer_namespace: str | None = None, + namespace: str | None = None, codestyle: CodeStyle | None = None, argv: Sequence[str] | None = None, project_info: Any = None, @@ -28,7 +28,7 @@ class SfgContext: self._argv = argv self._project_info = project_info - self._outer_namespace = outer_namespace + self._outer_namespace = namespace self._inner_namespace: str | None = None self._codestyle = codestyle if codestyle is not None else CodeStyle() @@ -39,8 +39,8 @@ class SfgContext: self._global_namespace = SfgGlobalNamespace() current_ns: SfgNamespace = self._global_namespace - if outer_namespace is not None: - for token in outer_namespace.split("::"): + if namespace is not None: + for token in namespace.split("::"): current_ns = SfgNamespace(token, current_ns) self._cursor = SfgCursor(self, current_ns) @@ -104,8 +104,10 @@ class SfgCursor: self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] = dict() for f in self._ctx.files: - if self._cur_namespace is not None: - block = SfgNamespaceBlock(self._cur_namespace) + if not isinstance(namespace, SfgGlobalNamespace): + block = SfgNamespaceBlock( + self._cur_namespace, self._cur_namespace.fqname + ) f.elements.append(block) self._loc[f] = block.elements else: diff --git a/src/pystencilssfg/emission/clang_format.py b/src/pystencilssfg/emission/clang_format.py index 1b15e8c4145db67656b29acdd456074ad2d0a895..b73d9da973bc107a56f7fef319ee8c0ce3ad5f5f 100644 --- a/src/pystencilssfg/emission/clang_format.py +++ b/src/pystencilssfg/emission/clang_format.py @@ -5,14 +5,16 @@ from ..config import ClangFormatOptions from ..exceptions import SfgException -def invoke_clang_format(code: str, options: ClangFormatOptions) -> str: +def invoke_clang_format( + code: str, options: ClangFormatOptions, sort_includes: str | None = None +) -> str: """Call the `clang-format` command-line tool to format the given code string according to the given style arguments. Args: code: Code string to format - codestyle: [SfgCodeStyle][pystencilssfg.configuration.SfgCodeStyle] object - defining the `clang-format` binary and the desired code style. + options: Options controlling the clang-format invocation + sort_includes: Option to be passed on to clang-format's ``--sort-includes`` argument Returns: The formatted code, if `clang-format` was run sucessfully. @@ -31,6 +33,9 @@ def invoke_clang_format(code: str, options: ClangFormatOptions) -> str: force = options.get_option("force") style = options.get_option("code_style") args = [binary, f"--style={style}"] + + if sort_includes is not None: + args += ["--sort-includes", sort_includes] if not shutil.which(binary): if force: diff --git a/src/pystencilssfg/emission/emitter.py b/src/pystencilssfg/emission/emitter.py index 440344b0d188fb5a2a3b987b9fcbbf029df3a9d5..c1b6e9c79e09bf1d5604dbd6fca9304190e85272 100644 --- a/src/pystencilssfg/emission/emitter.py +++ b/src/pystencilssfg/emission/emitter.py @@ -17,12 +17,21 @@ class SfgCodeEmitter: clang_format: ClangFormatOptions, ): self._output_dir = output_directory + self._code_style = code_style self._clang_format_opts = clang_format self._printer = SfgFilePrinter(code_style) def emit(self, file: SfgSourceFile): code = self._printer(file) - code = invoke_clang_format(code, self._clang_format_opts) + + if self._code_style.get_option("includes_sorting_key") is not None: + sort_includes = "Never" + else: + sort_includes = None + + code = invoke_clang_format( + code, self._clang_format_opts, sort_includes=sort_includes + ) self._output_dir.mkdir(parents=True, exist_ok=True) fpath = self._output_dir / file.name diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py index f92fba468be8a440639f1ed245cf6bb99bde12b7..8216a7b1923b8b23e19a4fbece8d507aa6260d27 100644 --- a/src/pystencilssfg/emission/file_printer.py +++ b/src/pystencilssfg/emission/file_printer.py @@ -63,12 +63,12 @@ class SfgFilePrinter: match elem: case str(): return elem - case SfgNamespaceBlock(namespace, elements): - code = f"namespace {namespace.name} {{\n" + case SfgNamespaceBlock(_, elements, label): + code = f"namespace {label} {{\n" code += self._code_style.indent( "\n\n".join(self.visit(e) for e in elements) ) - code += f"\n}} // namespace {namespace.name}" + code += f"\n}} // namespace {label}" return code case SfgEntityDecl(entity): return self.visit_decl(entity, inclass) @@ -157,7 +157,7 @@ class SfgFilePrinter: code = f"{cls.class_keyword} {cls.name} {{\n" vblocks_str = [self._visibility_block(b) for b in vblocks] code += "\n\n".join(vblocks_str) - code += "\n}\n" + code += "\n};\n" return code case _: diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index 1191ebeea9ca3633afe1cf92635c7c55f8abe692..dd9a78c61d73bb37dd9ca19d64a07187e87735f1 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -1,5 +1,6 @@ from pathlib import Path +from typing import Callable, Any from .config import ( SfgConfig, CommandLineParameters, @@ -102,11 +103,16 @@ class SourceFileGenerator: output_files[1].name, SfgSourceFileType.HEADER ) + # TODO: Find a way to not hard-code the restrict qualifier in pystencils + self._header_file.elements.append("#define RESTRICT __restrict__") + outer_namespace: str | _GlobalNamespace = config.get_option("outer_namespace") match (outer_namespace, namespace): case [_GlobalNamespace(), None]: namespace = None - case [_GlobalNamespace(), nspace] | [nspace, None]: + case [_GlobalNamespace(), nspace] if nspace is not None: + namespace = nspace + case [nspace, None]: namespace = nspace case [outer, inner]: namespace = f"{outer}::{inner}" @@ -124,6 +130,16 @@ class SourceFileGenerator: self._output_dir, config.codestyle, config.clang_format ) + sort_key = config.codestyle.get_option("includes_sorting_key") + if sort_key is None: + + def default_key(h: HeaderFile): + return str(h) + + sort_key = default_key + + self._include_sort_key: Callable[[HeaderFile], Any] = sort_key + def clean_files(self): header_path = self._output_dir / self._header_file.name if header_path.exists(): @@ -144,12 +160,22 @@ class SourceFileGenerator: assert self._impl_file is not None self._header_file.elements.append(f'#include "{self._impl_file.name}"') - # Collect header files for inclusion - # from .ir import collect_includes + from .ir import collect_includes + + header_includes = collect_includes(self._header_file) + self._header_file.includes = list( + set(self._header_file.includes) | header_includes + ) + self._header_file.includes.sort(key=self._include_sort_key) - # TODO: Collect headers - # for header in collect_includes(self._context): - # self._context.add_include(SfgHeaderInclude(header)) + if self._impl_file is not None: + impl_includes = collect_includes(self._impl_file) + # If some header is already included by the generated header file, do not duplicate that inclusion + impl_includes -= header_includes + self._impl_file.includes = list( + set(self._impl_file.includes) | impl_includes + ) + self._impl_file.includes.sort(key=self._include_sort_key) self._emitter.emit(self._header_file) if self._impl_file is not None: diff --git a/src/pystencilssfg/ir/analysis.py b/src/pystencilssfg/ir/analysis.py index b88c4f3fe3e2265c281734ec8725536cb609c130..a2bce074fb707304ed8431d9fa1ddbe66118dbe1 100644 --- a/src/pystencilssfg/ir/analysis.py +++ b/src/pystencilssfg/ir/analysis.py @@ -1,54 +1,103 @@ from __future__ import annotations -from typing import Any from functools import reduce -from ..exceptions import SfgException from ..lang import HeaderFile, includes +from .syntax import ( + SfgSourceFile, + SfgNamespaceElement, + SfgClassBodyElement, + SfgNamespaceBlock, + SfgEntityDecl, + SfgEntityDef, + SfgClassBody, + SfgVisibilityBlock, +) -def collect_includes(obj: Any) -> set[HeaderFile]: +def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: from .call_tree import SfgCallTreeNode from .entities import ( + SfgCodeEntity, + SfgKernelHandle, SfgFunction, - SfgClass, + SfgMethod, + SfgClassMember, SfgConstructor, SfgMemberVariable, ) - match obj: - # TODO - - case SfgCallTreeNode(): - return reduce( - lambda accu, child: accu | collect_includes(child), - obj.children, - obj.required_includes, - ) - - case SfgFunction(_, tree, parameters): - param_headers: set[HeaderFile] = reduce( - set.union, (includes(p) for p in parameters), set() - ) - return param_headers | collect_includes(tree) - - case SfgClass(): - return reduce( - lambda accu, member: accu | (collect_includes(member)), - obj.members(), - set(), - ) - - case SfgConstructor(_, parameters): - param_headers = reduce( - set.union, (includes(p) for p in parameters), set() - ) - return param_headers - - case SfgMemberVariable(): - return includes(obj) - - case _: - raise SfgException( - f"Can't collect includes from object of type {type(obj)}" - ) + def visit_decl(entity: SfgCodeEntity | SfgClassMember) -> set[HeaderFile]: + match entity: + case ( + SfgKernelHandle(_, parameters) + | SfgFunction(_, _, parameters) + | SfgMethod(_, _, parameters) + | SfgConstructor(_, parameters, _, _) + ): + incls = reduce(set.union, (includes(p) for p in parameters), set()) + if isinstance(entity, (SfgFunction, SfgMethod)): + incls |= includes(entity.return_type) + return incls + + case SfgMemberVariable(): + return includes(entity) + + case _: + assert False, "unexpected entity" + + def walk_syntax( + obj: ( + SfgNamespaceElement + | SfgClassBodyElement + | SfgVisibilityBlock + | SfgCallTreeNode + ), + ) -> set[HeaderFile]: + match obj: + case str(): + return set() + + case SfgCallTreeNode(): + return reduce( + lambda accu, child: accu | walk_syntax(child), + obj.children, + obj.required_includes, + ) + + case SfgEntityDecl(entity): + return visit_decl(entity) + + case SfgEntityDef(entity): + match entity: + case SfgKernelHandle(kernel, _): + return set( + HeaderFile.parse(h) for h in kernel.required_headers + ) | visit_decl(entity) + + case SfgFunction(_, tree, _) | SfgMethod(_, tree, _): + return walk_syntax(tree) | visit_decl(entity) + + case SfgConstructor(): + return visit_decl(entity) + + case SfgMemberVariable(): + return includes(entity) + + case _: + assert False, "unexpected entity" + + case SfgNamespaceBlock(_, elements) | SfgVisibilityBlock(_, elements): + return reduce( + lambda accu, elem: accu | walk_syntax(elem), elements, set() + ) + + case SfgClassBody(_, vblocks): + return reduce( + lambda accu, vblock: accu | walk_syntax(vblock), vblocks, set() + ) + + case _: + assert False, "unexpected syntax element" + + return reduce(lambda accu, elem: accu | walk_syntax(elem), file.elements, set()) diff --git a/src/pystencilssfg/ir/entities.py b/src/pystencilssfg/ir/entities.py index 2b0e3e5c3e5fd2f9db9f49922fed3bdedad331ef..90205fea5515a4d6b8be58cc97943eebbdfb0e03 100644 --- a/src/pystencilssfg/ir/entities.py +++ b/src/pystencilssfg/ir/entities.py @@ -100,7 +100,7 @@ class SfgGlobalNamespace(SfgNamespace): class SfgKernelHandle(SfgCodeEntity): """Handle to a pystencils kernel.""" - __match_args__ = ("kernel",) + __match_args__ = ("kernel", "parameters") def __init__(self, name: str, namespace: SfgKernelNamespace, kernel: Kernel): super().__init__(name, namespace) @@ -219,7 +219,7 @@ class SfgKernelNamespace(SfgNamespace): class SfgFunction(SfgCodeEntity): """A free function.""" - __match_args__ = ("name", "tree", "parameters") + __match_args__ = ("name", "tree", "parameters", "return_type") def __init__( self, @@ -337,7 +337,7 @@ class SfgMemberVariable(SfgVar, SfgClassMember): class SfgMethod(SfgClassMember): """Instance method of a class""" - __match_args__ = ("name", "tree", "parameters") + __match_args__ = ("name", "tree", "parameters", "return_type") def __init__( self, diff --git a/src/pystencilssfg/ir/syntax.py b/src/pystencilssfg/ir/syntax.py index 574c02930d224b6f130ad7b7b337abdb41af36c2..699e7b58f679287620117cb451bd4fecc3d20d8a 100644 --- a/src/pystencilssfg/ir/syntax.py +++ b/src/pystencilssfg/ir/syntax.py @@ -80,6 +80,8 @@ class SfgVisibilityBlock: visibility: The visibility qualifier of this block """ + __match_args__ = ("visibility", "elements") + def __init__(self, visibility: SfgVisibility) -> None: self._vis = visibility self._elements: list[SfgClassBodyElement] = [] @@ -107,29 +109,34 @@ class SfgVisibilityBlock: class SfgNamespaceBlock: - """A C++ namespace. - - Each namespace has a `name` and a `parent`; its fully qualified name is given as - ``<parent.name>::<name>``. + """A C++ namespace block. Args: - name: Local name of this namespace - parent: Parent namespace enclosing this namespace + namespace: Namespace associated with this block + label: Label printed at the opening brace of this block. + This may be the namespace name, or a compressed qualified + name containing one or more of its parent namespaces. """ __match_args__ = ( "namespace", "elements", + "label", ) - def __init__(self, namespace: SfgNamespace) -> None: + def __init__(self, namespace: SfgNamespace, label: str | None = None) -> None: self._namespace = namespace + self._label = label if label is not None else namespace.name self._elements: list[SfgNamespaceElement] = [] @property def namespace(self) -> SfgNamespace: return self._namespace + @property + def label(self) -> str: + return self._label + @property def elements(self) -> list[SfgNamespaceElement]: """Sequence of source elements that make up the body of this namespace""" diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 4a1f7e9e7aa49d0534f4d27e3b038eb7c72d3c25..b3ed18e276c2485702247a651d3d9a7c593d594c 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -7,7 +7,7 @@ import sympy as sp from pystencils import TypedSymbol from pystencils.codegen import Parameter -from pystencils.types import PsType, UserTypeSpec, create_type +from pystencils.types import PsType, PsScalarType, UserTypeSpec, create_type from ..exceptions import SfgException from .headers import HeaderFile @@ -452,7 +452,7 @@ def depends(expr: ExprLike) -> set[SfgVar]: raise ValueError(f"Invalid expression: {expr}") -def includes(expr: ExprLike) -> set[HeaderFile]: +def includes(obj: ExprLike | PsType) -> set[HeaderFile]: """Determine the set of header files an expression depends on. Args: @@ -465,21 +465,33 @@ def includes(expr: ExprLike) -> set[HeaderFile]: ValueError: If the argument was not a valid variable or expression """ - match expr: + if isinstance(obj, PsType): + obj = strip_ptr_ref(obj) + + match obj: + case CppType(): + return set(obj.includes) + + case PsType(): + headers = set(HeaderFile.parse(h) for h in obj.required_headers) + if isinstance(obj, PsScalarType): + headers.add(HeaderFile.parse("<cstdint>")) + return headers + case SfgVar(_, dtype): - match dtype: - case CppType(): - return set(dtype.includes) - case _: - return set(HeaderFile.parse(h) for h in dtype.required_headers) + return includes(dtype) + case TypedSymbol(): - return includes(asvar(expr)) + return includes(asvar(obj)) + case str(): return set() + case AugExpr(): - return set(expr.includes) + return set(obj.includes) + case _: - raise ValueError(f"Invalid expression: {expr}") + raise ValueError(f"Invalid expression: {obj}") class IFieldExtraction(ABC):