Skip to content
Snippets Groups Projects
Commit 8dd55ef2 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fix CppType factory: Fix cloning in `(de)constify`, enable immediate creation of `ref`s

parent 11f15805
No related branches found
No related tags found
1 merge request!12Improve versatility and robustness of `cpptype`, and document it in the user guide
Pipeline #71501 failed
from __future__ import annotations
from typing import Any, Iterable from typing import Any, Iterable
from abc import ABC from abc import ABC
from pystencils.types import PsType, PsPointerType, PsCustomType from pystencils.types import PsType, PsPointerType, PsCustomType
...@@ -50,16 +51,26 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] ...@@ -50,16 +51,26 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile]
class TypeClass(CppType): class TypeClass(CppType):
includes = frozenset(HeaderFile.parse(h) for h in headers) includes = frozenset(HeaderFile.parse(h) for h in headers)
def __init__(self, *template_args, const: bool = False, **template_kwargs): class TypeClassFactory(TypeClass):
def __new__( # type: ignore
cls,
*template_args,
const: bool = False,
ref: bool = False,
**template_kwargs,
) -> PsType:
template_args = tuple(_fixarg(arg) for arg in template_args) template_args = tuple(_fixarg(arg) for arg in template_args)
template_kwargs = { template_kwargs = {
key: _fixarg(value) for key, value in template_kwargs.items() key: _fixarg(value) for key, value in template_kwargs.items()
} }
name = typestr.format(*template_args, **template_kwargs) name = typestr.format(*template_args, **template_kwargs)
super().__init__(name, const) obj: PsType = TypeClass(name, const)
if ref:
obj = Ref(obj)
return obj
return TypeClass return TypeClassFactory
class Ref(PsType): class Ref(PsType):
......
from pystencilssfg.lang import cpptype, HeaderFile from pystencilssfg.lang import cpptype, HeaderFile, Ref, strip_ptr_ref
from pystencils import create_type from pystencils import create_type
from pystencils.types import constify, deconstify
def test_cpptypes(): def test_cpptypes():
...@@ -12,3 +13,23 @@ def test_cpptypes(): ...@@ -12,3 +13,23 @@ def test_cpptypes():
== vec_type.includes == vec_type.includes
== {HeaderFile("vector", system_header=True)} == {HeaderFile("vector", system_header=True)}
) )
assert deconstify(constify(vec_type)) == vec_type
def test_cpptype_const():
tclass = cpptype("std::vector< {T} >", "<vector>")
vec_type = tclass(T=create_type("uint32"))
assert constify(vec_type) == tclass(T=create_type("uint32"), const=True)
vec_type = tclass(T=create_type("uint32"), const=True)
assert deconstify(vec_type) == tclass(T=create_type("uint32"), const=False)
def test_cpptype_ref():
tclass = cpptype("std::vector< {T} >", "<vector>")
vec_type = tclass(T=create_type("uint32"), ref=True)
assert isinstance(vec_type, Ref)
assert strip_ptr_ref(vec_type) == tclass(T=create_type("uint32"))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment