diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py index b3a634fbb63bb93d6c8e504030590df477cd200d..a41246d17bea7d61ad75fc526c9176810e44d11d 100644 --- a/src/pystencilssfg/lang/types.py +++ b/src/pystencilssfg/lang/types.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Any, Iterable from abc import ABC from pystencils.types import PsType, PsPointerType, PsCustomType @@ -50,16 +51,26 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] class TypeClass(CppType): includes = frozenset(HeaderFile.parse(h) for h in headers) - def __init__(self, *template_args, const: bool = False, **template_kwargs): + class TypeClassFactory(TypeClass): + def __new__( # type: ignore + cls, + *template_args, + const: bool = False, + ref: bool = False, + **template_kwargs, + ) -> PsType: template_args = tuple(_fixarg(arg) for arg in template_args) template_kwargs = { key: _fixarg(value) for key, value in template_kwargs.items() } name = typestr.format(*template_args, **template_kwargs) - super().__init__(name, const) + obj: PsType = TypeClass(name, const) + if ref: + obj = Ref(obj) + return obj - return TypeClass + return TypeClassFactory class Ref(PsType): diff --git a/tests/lang/test_types.py b/tests/lang/test_types.py index 44b9a7c23a273b703a5c4d49386258f72bb2f726..79e92403c86de6a647a6ee33c5c94ee54fa83144 100644 --- a/tests/lang/test_types.py +++ b/tests/lang/test_types.py @@ -1,5 +1,6 @@ -from pystencilssfg.lang import cpptype, HeaderFile +from pystencilssfg.lang import cpptype, HeaderFile, Ref, strip_ptr_ref from pystencils import create_type +from pystencils.types import constify, deconstify def test_cpptypes(): @@ -12,3 +13,23 @@ def test_cpptypes(): == vec_type.includes == {HeaderFile("vector", system_header=True)} ) + + assert deconstify(constify(vec_type)) == vec_type + + +def test_cpptype_const(): + tclass = cpptype("std::vector< {T} >", "<vector>") + + vec_type = tclass(T=create_type("uint32")) + assert constify(vec_type) == tclass(T=create_type("uint32"), const=True) + + vec_type = tclass(T=create_type("uint32"), const=True) + assert deconstify(vec_type) == tclass(T=create_type("uint32"), const=False) + + +def test_cpptype_ref(): + tclass = cpptype("std::vector< {T} >", "<vector>") + + vec_type = tclass(T=create_type("uint32"), ref=True) + assert isinstance(vec_type, Ref) + assert strip_ptr_ref(vec_type) == tclass(T=create_type("uint32"))