From 8dd55ef2edb78236e7f1c11be52a0000e2b815ab Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 20 Dec 2024 18:26:30 +0100 Subject: [PATCH] Fix CppType factory: Fix cloning in `(de)constify`, enable immediate creation of `ref`s --- src/pystencilssfg/lang/types.py | 17 ++++++++++++++--- tests/lang/test_types.py | 23 ++++++++++++++++++++++- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py index b3a634f..a41246d 100644 --- a/src/pystencilssfg/lang/types.py +++ b/src/pystencilssfg/lang/types.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Any, Iterable from abc import ABC from pystencils.types import PsType, PsPointerType, PsCustomType @@ -50,16 +51,26 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] class TypeClass(CppType): includes = frozenset(HeaderFile.parse(h) for h in headers) - def __init__(self, *template_args, const: bool = False, **template_kwargs): + class TypeClassFactory(TypeClass): + def __new__( # type: ignore + cls, + *template_args, + const: bool = False, + ref: bool = False, + **template_kwargs, + ) -> PsType: template_args = tuple(_fixarg(arg) for arg in template_args) template_kwargs = { key: _fixarg(value) for key, value in template_kwargs.items() } name = typestr.format(*template_args, **template_kwargs) - super().__init__(name, const) + obj: PsType = TypeClass(name, const) + if ref: + obj = Ref(obj) + return obj - return TypeClass + return TypeClassFactory class Ref(PsType): diff --git a/tests/lang/test_types.py b/tests/lang/test_types.py index 44b9a7c..79e9240 100644 --- a/tests/lang/test_types.py +++ b/tests/lang/test_types.py @@ -1,5 +1,6 @@ -from pystencilssfg.lang import cpptype, HeaderFile +from pystencilssfg.lang import cpptype, HeaderFile, Ref, strip_ptr_ref from pystencils import create_type +from pystencils.types import constify, deconstify def test_cpptypes(): @@ -12,3 +13,23 @@ def test_cpptypes(): == vec_type.includes == {HeaderFile("vector", system_header=True)} ) + + assert deconstify(constify(vec_type)) == vec_type + + +def test_cpptype_const(): + tclass = cpptype("std::vector< {T} >", "<vector>") + + vec_type = tclass(T=create_type("uint32")) + assert constify(vec_type) == tclass(T=create_type("uint32"), const=True) + + vec_type = tclass(T=create_type("uint32"), const=True) + assert deconstify(vec_type) == tclass(T=create_type("uint32"), const=False) + + +def test_cpptype_ref(): + tclass = cpptype("std::vector< {T} >", "<vector>") + + vec_type = tclass(T=create_type("uint32"), ref=True) + assert isinstance(vec_type, Ref) + assert strip_ptr_ref(vec_type) == tclass(T=create_type("uint32")) -- GitLab