From 80f224a1efae0427661847884d2493ab2d03121f Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 8 Jan 2025 11:25:00 +0100 Subject: [PATCH] Add header inheritance to CppType. Reintroduce TypeClassFactory for creating refs. Add new `ref=` API to STL classes. Fix some minor bugs. --- src/pystencilssfg/composer/basic_composer.py | 32 +----------- src/pystencilssfg/composer/class_composer.py | 4 +- src/pystencilssfg/emission/printers.py | 2 +- src/pystencilssfg/extensions/sycl.py | 4 +- src/pystencilssfg/lang/cpp/std_mdspan.py | 7 +-- src/pystencilssfg/lang/cpp/std_span.py | 7 +-- src/pystencilssfg/lang/cpp/std_tuple.py | 7 +-- src/pystencilssfg/lang/cpp/std_vector.py | 6 +-- src/pystencilssfg/lang/cpp/sycl_accessor.py | 6 +-- src/pystencilssfg/lang/expressions.py | 8 ++- src/pystencilssfg/lang/types.py | 52 ++++++++++++------- .../generator_scripts/source/Conditionals.py | 3 +- .../generator_scripts/source/SimpleClasses.py | 2 +- tests/lang/test_types.py | 14 ++++- 14 files changed, 71 insertions(+), 83 deletions(-) diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 90146e5..133f504 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -7,13 +7,7 @@ from functools import reduce from pystencils import Field from pystencils.backend import KernelFunction -from pystencils.types import ( - create_type, - UserTypeSpec, - PsCustomType, - PsPointerType, - PsType, -) +from pystencils.types import create_type, UserTypeSpec from ..context import SfgContext from .custom import CustomGenerator @@ -325,30 +319,6 @@ class SfgBasicComposer(SfgIComposer): """Use inside a function body to require the inclusion of headers.""" return SfgRequireIncludes((HeaderFile.parse(incl) for incl in incls)) - def cpptype( - self, - typename: UserTypeSpec, - ptr: bool = False, - ref: bool = False, - const: bool = False, - ) -> PsType: - if ptr and ref: - raise SfgException("Create either a pointer, or a ref type, not both!") - - ref_qual = "&" if ref else "" - try: - base_type = create_type(typename) - except ValueError: - if not isinstance(typename, str): - raise ValueError(f"Could not parse type: {typename}") - - base_type = PsCustomType(typename + ref_qual, const=const) - - if ptr: - return PsPointerType(base_type) - else: - return base_type - def var(self, name: str, dtype: UserTypeSpec) -> AugExpr: """Create a variable with given name and data type.""" return AugExpr(create_type(dtype)).var(name) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 1f4c486..489823b 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Sequence -from pystencils.types import PsCustomType, UserTypeSpec +from pystencils.types import PsCustomType, UserTypeSpec, create_type from ..lang import ( _VarLike, @@ -177,7 +177,7 @@ class SfgClassComposer(SfgComposerMixIn): return SfgMethod( name, tree, - return_type=self._composer.cpptype(returns), + return_type=create_type(returns), inline=inline, const=const, ) diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index eb05ff6..3dc4a8f 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -166,7 +166,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgMethod) def sfg_method(self, method: SfgMethod): - code = f"{method.return_type} {method.name} ({self.param_list(method)})" + code = f"{method.return_type.c_string()} {method.name} ({self.param_list(method)})" code += "const" if method.const else "" if method.inline: code += ( diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index 349e030..88dbc9b 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -57,9 +57,7 @@ class SyclRange(AugExpr): _template = cpptype("sycl::range< {dims} >", "<sycl/sycl.hpp>") def __init__(self, dims: int, const: bool = False, ref: bool = False): - dtype = self._template(dims=dims, const=const) - if ref: - dtype = Ref(dtype) + dtype = self._template(dims=dims, const=const, ref=ref) super().__init__(dtype) diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py index 5f80ad6..6830885 100644 --- a/src/pystencilssfg/lang/cpp/std_mdspan.py +++ b/src/pystencilssfg/lang/cpp/std_mdspan.py @@ -11,7 +11,7 @@ from pystencils.types import ( from pystencilssfg.lang.expressions import AugExpr -from ...lang import SrcField, IFieldExtraction, cpptype, Ref, HeaderFile, ExprLike +from ...lang import SrcField, IFieldExtraction, cpptype, HeaderFile, ExprLike class StdMdspan(SrcField): @@ -111,11 +111,8 @@ class StdMdspan(SrcField): layout_policy = f"{self._namespace}::{layout_policy}" dtype = self._template( - T=T, extents=extents_str, layout_policy=layout_policy, const=const + T=T, extents=extents_str, layout_policy=layout_policy, const=const, ref=ref ) - - if ref: - dtype = Ref(dtype) super().__init__(dtype) self._extents_type = extents_str diff --git a/src/pystencilssfg/lang/cpp/std_span.py b/src/pystencilssfg/lang/cpp/std_span.py index 861a4c4..f161f48 100644 --- a/src/pystencilssfg/lang/cpp/std_span.py +++ b/src/pystencilssfg/lang/cpp/std_span.py @@ -1,7 +1,7 @@ from pystencils.field import Field from pystencils.types import UserTypeSpec, create_type, PsType -from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype, Ref +from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype class StdSpan(SrcField): @@ -9,10 +9,7 @@ class StdSpan(SrcField): def __init__(self, T: UserTypeSpec, ref=False, const=False): T = create_type(T) - dtype = self._template(T=T, const=const) - if ref: - dtype = Ref(dtype) - + dtype = self._template(T=T, const=const, ref=ref) self._element_type = T super().__init__(dtype) diff --git a/src/pystencilssfg/lang/cpp/std_tuple.py b/src/pystencilssfg/lang/cpp/std_tuple.py index bbf2ba3..58a3530 100644 --- a/src/pystencilssfg/lang/cpp/std_tuple.py +++ b/src/pystencilssfg/lang/cpp/std_tuple.py @@ -2,7 +2,7 @@ from typing import Sequence from pystencils.types import UserTypeSpec, create_type -from ...lang import SrcVector, AugExpr, cpptype, Ref +from ...lang import SrcVector, AugExpr, cpptype class StdTuple(SrcVector): @@ -18,10 +18,7 @@ class StdTuple(SrcVector): self._length = len(element_types) elt_type_strings = tuple(t.c_string() for t in self._element_types) - dtype = self._template(ts=", ".join(elt_type_strings), const=const) - if ref: - dtype = Ref(dtype) - + dtype = self._template(ts=", ".join(elt_type_strings), const=const, ref=ref) super().__init__(dtype) def extract_component(self, coordinate: int) -> AugExpr: diff --git a/src/pystencilssfg/lang/cpp/std_vector.py b/src/pystencilssfg/lang/cpp/std_vector.py index 5696b32..7e9291e 100644 --- a/src/pystencilssfg/lang/cpp/std_vector.py +++ b/src/pystencilssfg/lang/cpp/std_vector.py @@ -1,7 +1,7 @@ from pystencils.field import Field from pystencils.types import UserTypeSpec, create_type, PsType -from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype, Ref +from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype class StdVector(SrcVector, SrcField): @@ -15,9 +15,7 @@ class StdVector(SrcVector, SrcField): const: bool = False, ): T = create_type(T) - dtype = self._template(T=T, const=const) - if ref: - dtype = Ref(dtype) + dtype = self._template(T=T, const=const, ref=ref) super().__init__(dtype) self._element_type = T diff --git a/src/pystencilssfg/lang/cpp/sycl_accessor.py b/src/pystencilssfg/lang/cpp/sycl_accessor.py index f01c53d..4bcad56 100644 --- a/src/pystencilssfg/lang/cpp/sycl_accessor.py +++ b/src/pystencilssfg/lang/cpp/sycl_accessor.py @@ -3,7 +3,7 @@ from ...lang import SrcField, IFieldExtraction from pystencils import Field from pystencils.types import UserTypeSpec, create_type -from ...lang import AugExpr, cpptype, Ref +from ...lang import AugExpr, cpptype class SyclAccessor(SrcField): @@ -29,9 +29,7 @@ class SyclAccessor(SrcField): T = create_type(T) if dimensions > 3: raise ValueError("sycl accessors can only have dims 1, 2 or 3") - dtype = self._template(T=T, dims=dimensions, const=const) - if ref: - dtype = Ref(dtype) + dtype = self._template(T=T, dims=dimensions, const=const, ref=ref) super().__init__(dtype) diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 53064bd..03818c6 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -147,7 +147,7 @@ class VarExpr(DependentExpression): incls: Iterable[HeaderFile] match base_type: case CppType(): - incls = base_type.includes + incls = base_type.class_includes case _: incls = ( HeaderFile.parse(header) for header in var.dtype.required_headers @@ -408,7 +408,11 @@ def includes(expr: ExprLike) -> set[HeaderFile]: match expr: case SfgVar(_, dtype): - return set(HeaderFile.parse(h) for h in dtype.required_headers) + match dtype: + case CppType(): + return set(dtype.includes) + case _: + return set(HeaderFile.parse(h) for h in dtype.required_headers) case TypedSymbol(): return includes(asvar(expr)) case str(): diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py index e437ae6..5f064d5 100644 --- a/src/pystencilssfg/lang/types.py +++ b/src/pystencilssfg/lang/types.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any, Iterable, Sequence, Mapping from abc import ABC from dataclasses import dataclass +from itertools import chain import string @@ -38,7 +39,9 @@ class _TemplateArgFormatter(string.Formatter): def check_unused_args( self, used_args: set[int | str], args: Sequence, kwargs: Mapping[str, Any] ) -> None: - max_args_len: int = max((k for k in used_args if isinstance(k, int)), default=-1) + 1 + max_args_len: int = ( + max((k for k in used_args if isinstance(k, int)), default=-1) + 1 + ) if len(args) > max_args_len: raise ValueError( f"Too many positional arguments: Expected {max_args_len}, but got {len(args)}" @@ -56,21 +59,9 @@ class _TemplateArgs: class CppType(PsCustomType, ABC): - includes: frozenset[HeaderFile] + class_includes: frozenset[HeaderFile] template_string: str - def __new__( # type: ignore - cls, - *args, - ref: bool = False, - **kwargs, - ) -> CppType | Ref: - if ref: - obj = cls(*args, **kwargs) - return Ref(obj) - else: - return super().__new__(cls) - def __init__(self, *template_args, const: bool = False, **template_kwargs): # Support for cloning CppTypes if template_args and isinstance(template_args[0], _TemplateArgs): @@ -89,17 +80,34 @@ class CppType(PsCustomType, ABC): name = formatter.format(self.template_string, *pargs, **kwargs) self._targs = targs + self._includes = self.class_includes + + for arg in chain(pargs, kwargs.values()): + match arg: + case CppType(): + self._includes |= arg.includes + case PsType(): + self._includes |= { + HeaderFile.parse(h) for h in arg.required_headers + } + super().__init__(name, const=const) def __args__(self) -> tuple[Any, ...]: return (self._targs,) + @property + def includes(self) -> frozenset[HeaderFile]: + return self._includes + @property def required_headers(self) -> set[str]: - return set(str(h) for h in self.includes) + return set(str(h) for h in self.class_includes) -def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()): +def cpptype( + typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] = () +) -> type[CppType]: headers: list[str | HeaderFile] if isinstance(include, (str, HeaderFile)): @@ -111,9 +119,17 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] class TypeClass(CppType): template_string = typestr - includes = frozenset(HeaderFile.parse(h) for h in headers) + class_includes = frozenset(HeaderFile.parse(h) for h in headers) + + class TypeClassFactory(TypeClass): + def __new__(cls, *args, ref: bool = False, **kwargs): + obj = TypeClass(*args, **kwargs) + if ref: + return Ref(obj) + else: + return obj - return TypeClass + return TypeClassFactory class Ref(PsType): diff --git a/tests/generator_scripts/source/Conditionals.py b/tests/generator_scripts/source/Conditionals.py index d1088a9..9016b73 100644 --- a/tests/generator_scripts/source/Conditionals.py +++ b/tests/generator_scripts/source/Conditionals.py @@ -1,4 +1,5 @@ from pystencilssfg import SourceFileGenerator +from pystencils.types import PsCustomType with SourceFileGenerator() as sfg: sfg.namespace("gen") @@ -6,7 +7,7 @@ with SourceFileGenerator() as sfg: sfg.include("<iostream>") sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };") - noodle = sfg.var("noodle", sfg.cpptype("Noodles")) + noodle = sfg.var("noodle", PsCustomType("Noodles")) sfg.function("printOpinion")( sfg.switch(noodle) diff --git a/tests/generator_scripts/source/SimpleClasses.py b/tests/generator_scripts/source/SimpleClasses.py index 64093f5..454f1a2 100644 --- a/tests/generator_scripts/source/SimpleClasses.py +++ b/tests/generator_scripts/source/SimpleClasses.py @@ -16,7 +16,7 @@ with SourceFileGenerator() as sfg: .init(y_)(y) .init(z_)(z), - sfg.method("getX", returns="const int64_t &", const=True, inline=True)( + sfg.method("getX", returns="const int64_t", const=True, inline=True)( "return this->x_;" ) ), diff --git a/tests/lang/test_types.py b/tests/lang/test_types.py index 35713ea..8bd0903 100644 --- a/tests/lang/test_types.py +++ b/tests/lang/test_types.py @@ -11,7 +11,7 @@ def test_cpptypes(): vec_type = tclass(create_type("float32"), "std::allocator< float >") assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >" assert ( - tclass.includes + tclass.class_includes == vec_type.includes == {HeaderFile("vector", system_header=True)} ) @@ -65,3 +65,15 @@ def test_cpptype_ref(): vec_type = tclass(T=create_type("uint32"), ref=True) assert isinstance(vec_type, Ref) assert strip_ptr_ref(vec_type) == tclass(T=create_type("uint32")) + + +def test_cpptype_inherits_headers(): + optional_tclass = cpptype("std::optional< {T} >", "<optional>") + vec_tclass = cpptype("std::vector< {T} >", "<vector>") + + vec_type = vec_tclass(T=optional_tclass(T="int")) + assert ( + vec_type.includes + == optional_tclass.class_includes | vec_tclass.class_includes + == {HeaderFile.parse("<optional>"), HeaderFile.parse("<vector>")} + ) -- GitLab