diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 90146e596123eedc73eea5f813964f7593b6f43f..133f504b140ee36dd7db7c7c01d8bb710fa9c8c5 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -7,13 +7,7 @@ from functools import reduce from pystencils import Field from pystencils.backend import KernelFunction -from pystencils.types import ( - create_type, - UserTypeSpec, - PsCustomType, - PsPointerType, - PsType, -) +from pystencils.types import create_type, UserTypeSpec from ..context import SfgContext from .custom import CustomGenerator @@ -325,30 +319,6 @@ class SfgBasicComposer(SfgIComposer): """Use inside a function body to require the inclusion of headers.""" return SfgRequireIncludes((HeaderFile.parse(incl) for incl in incls)) - def cpptype( - self, - typename: UserTypeSpec, - ptr: bool = False, - ref: bool = False, - const: bool = False, - ) -> PsType: - if ptr and ref: - raise SfgException("Create either a pointer, or a ref type, not both!") - - ref_qual = "&" if ref else "" - try: - base_type = create_type(typename) - except ValueError: - if not isinstance(typename, str): - raise ValueError(f"Could not parse type: {typename}") - - base_type = PsCustomType(typename + ref_qual, const=const) - - if ptr: - return PsPointerType(base_type) - else: - return base_type - def var(self, name: str, dtype: UserTypeSpec) -> AugExpr: """Create a variable with given name and data type.""" return AugExpr(create_type(dtype)).var(name) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 1f4c4865987c1f10ec133f95d11e4bccc1ef8b76..489823b9ce619be88e3220ce4b941cf49c62b298 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Sequence -from pystencils.types import PsCustomType, UserTypeSpec +from pystencils.types import PsCustomType, UserTypeSpec, create_type from ..lang import ( _VarLike, @@ -177,7 +177,7 @@ class SfgClassComposer(SfgComposerMixIn): return SfgMethod( name, tree, - return_type=self._composer.cpptype(returns), + return_type=create_type(returns), inline=inline, const=const, ) diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index eb05ff6efa8eb8166c6d5f81e8e4596c6c797704..3dc4a8f74177c4de16158b1e6f314be25cca9c18 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -166,7 +166,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgMethod) def sfg_method(self, method: SfgMethod): - code = f"{method.return_type} {method.name} ({self.param_list(method)})" + code = f"{method.return_type.c_string()} {method.name} ({self.param_list(method)})" code += "const" if method.const else "" if method.inline: code += ( diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index 349e030c6c9057472b9f513258c8a979bba16b8a..88dbc9be2e215b1fdce5833ef18eac6eab336d74 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -57,9 +57,7 @@ class SyclRange(AugExpr): _template = cpptype("sycl::range< {dims} >", "<sycl/sycl.hpp>") def __init__(self, dims: int, const: bool = False, ref: bool = False): - dtype = self._template(dims=dims, const=const) - if ref: - dtype = Ref(dtype) + dtype = self._template(dims=dims, const=const, ref=ref) super().__init__(dtype) diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py index 5f80ad6f15d71ecc12d6f5b55b8671e1dd1124d2..68308850f9dd7d5c3486265fb6c9b634d97c6fb7 100644 --- a/src/pystencilssfg/lang/cpp/std_mdspan.py +++ b/src/pystencilssfg/lang/cpp/std_mdspan.py @@ -11,7 +11,7 @@ from pystencils.types import ( from pystencilssfg.lang.expressions import AugExpr -from ...lang import SrcField, IFieldExtraction, cpptype, Ref, HeaderFile, ExprLike +from ...lang import SrcField, IFieldExtraction, cpptype, HeaderFile, ExprLike class StdMdspan(SrcField): @@ -111,11 +111,8 @@ class StdMdspan(SrcField): layout_policy = f"{self._namespace}::{layout_policy}" dtype = self._template( - T=T, extents=extents_str, layout_policy=layout_policy, const=const + T=T, extents=extents_str, layout_policy=layout_policy, const=const, ref=ref ) - - if ref: - dtype = Ref(dtype) super().__init__(dtype) self._extents_type = extents_str diff --git a/src/pystencilssfg/lang/cpp/std_span.py b/src/pystencilssfg/lang/cpp/std_span.py index 861a4c4bb1ea81b0cbaaef4cb683316274ab2edd..f161f4874f627fa8943f4e24c2a1082780259572 100644 --- a/src/pystencilssfg/lang/cpp/std_span.py +++ b/src/pystencilssfg/lang/cpp/std_span.py @@ -1,7 +1,7 @@ from pystencils.field import Field from pystencils.types import UserTypeSpec, create_type, PsType -from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype, Ref +from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype class StdSpan(SrcField): @@ -9,10 +9,7 @@ class StdSpan(SrcField): def __init__(self, T: UserTypeSpec, ref=False, const=False): T = create_type(T) - dtype = self._template(T=T, const=const) - if ref: - dtype = Ref(dtype) - + dtype = self._template(T=T, const=const, ref=ref) self._element_type = T super().__init__(dtype) diff --git a/src/pystencilssfg/lang/cpp/std_tuple.py b/src/pystencilssfg/lang/cpp/std_tuple.py index bbf2ba33b8f1a19081593501885d0dc935fc3055..58a3530b9e98e2c39e205fd7dac9845b4ff35bda 100644 --- a/src/pystencilssfg/lang/cpp/std_tuple.py +++ b/src/pystencilssfg/lang/cpp/std_tuple.py @@ -2,7 +2,7 @@ from typing import Sequence from pystencils.types import UserTypeSpec, create_type -from ...lang import SrcVector, AugExpr, cpptype, Ref +from ...lang import SrcVector, AugExpr, cpptype class StdTuple(SrcVector): @@ -18,10 +18,7 @@ class StdTuple(SrcVector): self._length = len(element_types) elt_type_strings = tuple(t.c_string() for t in self._element_types) - dtype = self._template(ts=", ".join(elt_type_strings), const=const) - if ref: - dtype = Ref(dtype) - + dtype = self._template(ts=", ".join(elt_type_strings), const=const, ref=ref) super().__init__(dtype) def extract_component(self, coordinate: int) -> AugExpr: diff --git a/src/pystencilssfg/lang/cpp/std_vector.py b/src/pystencilssfg/lang/cpp/std_vector.py index 5696b32dd55c6092db0ff298fdcbd79fa7df69f5..7e9291eab670a1f4f45996b60ea8b8b3e8f49ff4 100644 --- a/src/pystencilssfg/lang/cpp/std_vector.py +++ b/src/pystencilssfg/lang/cpp/std_vector.py @@ -1,7 +1,7 @@ from pystencils.field import Field from pystencils.types import UserTypeSpec, create_type, PsType -from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype, Ref +from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype class StdVector(SrcVector, SrcField): @@ -15,9 +15,7 @@ class StdVector(SrcVector, SrcField): const: bool = False, ): T = create_type(T) - dtype = self._template(T=T, const=const) - if ref: - dtype = Ref(dtype) + dtype = self._template(T=T, const=const, ref=ref) super().__init__(dtype) self._element_type = T diff --git a/src/pystencilssfg/lang/cpp/sycl_accessor.py b/src/pystencilssfg/lang/cpp/sycl_accessor.py index f01c53d24750dc3b4f0350134e9bccd6f8ea4c26..4bcad56cd4ef109faa66757075eeefb6a5b416d3 100644 --- a/src/pystencilssfg/lang/cpp/sycl_accessor.py +++ b/src/pystencilssfg/lang/cpp/sycl_accessor.py @@ -3,7 +3,7 @@ from ...lang import SrcField, IFieldExtraction from pystencils import Field from pystencils.types import UserTypeSpec, create_type -from ...lang import AugExpr, cpptype, Ref +from ...lang import AugExpr, cpptype class SyclAccessor(SrcField): @@ -29,9 +29,7 @@ class SyclAccessor(SrcField): T = create_type(T) if dimensions > 3: raise ValueError("sycl accessors can only have dims 1, 2 or 3") - dtype = self._template(T=T, dims=dimensions, const=const) - if ref: - dtype = Ref(dtype) + dtype = self._template(T=T, dims=dimensions, const=const, ref=ref) super().__init__(dtype) diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 53064bd13a201b0716545bf1e84677baa8df3be9..03818c6b481a5436deddb1a0266597e927e982eb 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -147,7 +147,7 @@ class VarExpr(DependentExpression): incls: Iterable[HeaderFile] match base_type: case CppType(): - incls = base_type.includes + incls = base_type.class_includes case _: incls = ( HeaderFile.parse(header) for header in var.dtype.required_headers @@ -408,7 +408,11 @@ def includes(expr: ExprLike) -> set[HeaderFile]: match expr: case SfgVar(_, dtype): - return set(HeaderFile.parse(h) for h in dtype.required_headers) + match dtype: + case CppType(): + return set(dtype.includes) + case _: + return set(HeaderFile.parse(h) for h in dtype.required_headers) case TypedSymbol(): return includes(asvar(expr)) case str(): diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py index e437ae6ea895a6e0db92c7e339aefbbf13e4967d..5f064d54fc9dcfd2f2178021e26a8c3446c50138 100644 --- a/src/pystencilssfg/lang/types.py +++ b/src/pystencilssfg/lang/types.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any, Iterable, Sequence, Mapping from abc import ABC from dataclasses import dataclass +from itertools import chain import string @@ -38,7 +39,9 @@ class _TemplateArgFormatter(string.Formatter): def check_unused_args( self, used_args: set[int | str], args: Sequence, kwargs: Mapping[str, Any] ) -> None: - max_args_len: int = max((k for k in used_args if isinstance(k, int)), default=-1) + 1 + max_args_len: int = ( + max((k for k in used_args if isinstance(k, int)), default=-1) + 1 + ) if len(args) > max_args_len: raise ValueError( f"Too many positional arguments: Expected {max_args_len}, but got {len(args)}" @@ -56,21 +59,9 @@ class _TemplateArgs: class CppType(PsCustomType, ABC): - includes: frozenset[HeaderFile] + class_includes: frozenset[HeaderFile] template_string: str - def __new__( # type: ignore - cls, - *args, - ref: bool = False, - **kwargs, - ) -> CppType | Ref: - if ref: - obj = cls(*args, **kwargs) - return Ref(obj) - else: - return super().__new__(cls) - def __init__(self, *template_args, const: bool = False, **template_kwargs): # Support for cloning CppTypes if template_args and isinstance(template_args[0], _TemplateArgs): @@ -89,17 +80,34 @@ class CppType(PsCustomType, ABC): name = formatter.format(self.template_string, *pargs, **kwargs) self._targs = targs + self._includes = self.class_includes + + for arg in chain(pargs, kwargs.values()): + match arg: + case CppType(): + self._includes |= arg.includes + case PsType(): + self._includes |= { + HeaderFile.parse(h) for h in arg.required_headers + } + super().__init__(name, const=const) def __args__(self) -> tuple[Any, ...]: return (self._targs,) + @property + def includes(self) -> frozenset[HeaderFile]: + return self._includes + @property def required_headers(self) -> set[str]: - return set(str(h) for h in self.includes) + return set(str(h) for h in self.class_includes) -def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()): +def cpptype( + typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] = () +) -> type[CppType]: headers: list[str | HeaderFile] if isinstance(include, (str, HeaderFile)): @@ -111,9 +119,17 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] class TypeClass(CppType): template_string = typestr - includes = frozenset(HeaderFile.parse(h) for h in headers) + class_includes = frozenset(HeaderFile.parse(h) for h in headers) + + class TypeClassFactory(TypeClass): + def __new__(cls, *args, ref: bool = False, **kwargs): + obj = TypeClass(*args, **kwargs) + if ref: + return Ref(obj) + else: + return obj - return TypeClass + return TypeClassFactory class Ref(PsType): diff --git a/tests/generator_scripts/source/Conditionals.py b/tests/generator_scripts/source/Conditionals.py index d1088a96925bf52ab3634cf439538ce1de2b58a7..9016b73744f78fef504bb4f09b6742a630a6b12d 100644 --- a/tests/generator_scripts/source/Conditionals.py +++ b/tests/generator_scripts/source/Conditionals.py @@ -1,4 +1,5 @@ from pystencilssfg import SourceFileGenerator +from pystencils.types import PsCustomType with SourceFileGenerator() as sfg: sfg.namespace("gen") @@ -6,7 +7,7 @@ with SourceFileGenerator() as sfg: sfg.include("<iostream>") sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };") - noodle = sfg.var("noodle", sfg.cpptype("Noodles")) + noodle = sfg.var("noodle", PsCustomType("Noodles")) sfg.function("printOpinion")( sfg.switch(noodle) diff --git a/tests/generator_scripts/source/SimpleClasses.py b/tests/generator_scripts/source/SimpleClasses.py index 64093f5744c61918bf505518e2c39dffbb525fad..454f1a26f8103f7c8b330f1a5b70b1b79f96ebbc 100644 --- a/tests/generator_scripts/source/SimpleClasses.py +++ b/tests/generator_scripts/source/SimpleClasses.py @@ -16,7 +16,7 @@ with SourceFileGenerator() as sfg: .init(y_)(y) .init(z_)(z), - sfg.method("getX", returns="const int64_t &", const=True, inline=True)( + sfg.method("getX", returns="const int64_t", const=True, inline=True)( "return this->x_;" ) ), diff --git a/tests/lang/test_types.py b/tests/lang/test_types.py index 35713ea7b4f4bb0797d8f9d414342d02d249dd5e..8bd0903b92c79f67cfcc8707585157174b80db08 100644 --- a/tests/lang/test_types.py +++ b/tests/lang/test_types.py @@ -11,7 +11,7 @@ def test_cpptypes(): vec_type = tclass(create_type("float32"), "std::allocator< float >") assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >" assert ( - tclass.includes + tclass.class_includes == vec_type.includes == {HeaderFile("vector", system_header=True)} ) @@ -65,3 +65,15 @@ def test_cpptype_ref(): vec_type = tclass(T=create_type("uint32"), ref=True) assert isinstance(vec_type, Ref) assert strip_ptr_ref(vec_type) == tclass(T=create_type("uint32")) + + +def test_cpptype_inherits_headers(): + optional_tclass = cpptype("std::optional< {T} >", "<optional>") + vec_tclass = cpptype("std::vector< {T} >", "<vector>") + + vec_type = vec_tclass(T=optional_tclass(T="int")) + assert ( + vec_type.includes + == optional_tclass.class_includes | vec_tclass.class_includes + == {HeaderFile.parse("<optional>"), HeaderFile.parse("<vector>")} + )