diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py index 5f064d54fc9dcfd2f2178021e26a8c3446c50138..4da3fa47b38e6a46b13b063ca37cb096f3af6ec0 100644 --- a/src/pystencilssfg/lang/types.py +++ b/src/pystencilssfg/lang/types.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Iterable, Sequence, Mapping +from typing import Any, Iterable, Sequence, Mapping, Callable from abc import ABC from dataclasses import dataclass from itertools import chain @@ -107,7 +107,7 @@ class CppType(PsCustomType, ABC): def cpptype( typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] = () -) -> type[CppType]: +) -> Callable[..., CppType | Ref]: headers: list[str | HeaderFile] if isinstance(include, (str, HeaderFile)): @@ -121,15 +121,14 @@ def cpptype( template_string = typestr class_includes = frozenset(HeaderFile.parse(h) for h in headers) - class TypeClassFactory(TypeClass): - def __new__(cls, *args, ref: bool = False, **kwargs): - obj = TypeClass(*args, **kwargs) - if ref: - return Ref(obj) - else: - return obj + def factory(*args, ref: bool = False, **kwargs): + obj = TypeClass(*args, **kwargs) + if ref: + return Ref(obj) + else: + return obj - return TypeClassFactory + return staticmethod(factory) class Ref(PsType): diff --git a/tests/lang/test_types.py b/tests/lang/test_types.py index 8bd0903b92c79f67cfcc8707585157174b80db08..62c4bb0bd8ecafc8cc5b2613122838dc56a80637 100644 --- a/tests/lang/test_types.py +++ b/tests/lang/test_types.py @@ -6,13 +6,12 @@ from pystencils.types import constify, deconstify def test_cpptypes(): - tclass = cpptype("std::vector< {}, {} >", "<vector>") + tfactory = cpptype("std::vector< {}, {} >", "<vector>") - vec_type = tclass(create_type("float32"), "std::allocator< float >") + vec_type = tfactory(create_type("float32"), "std::allocator< float >") assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >" assert ( - tclass.class_includes - == vec_type.includes + vec_type.includes == {HeaderFile("vector", system_header=True)} ) @@ -20,60 +19,59 @@ def test_cpptypes(): assert deconstify(constify(vec_type)) == vec_type # Duplicate Equality - assert tclass(create_type("float32"), "std::allocator< float >") == vec_type + assert tfactory(create_type("float32"), "std::allocator< float >") == vec_type # Not equal with different argument even though it produces the same string - assert tclass("float", "std::allocator< float >") != vec_type + assert tfactory("float", "std::allocator< float >") != vec_type # The same with keyword arguments - tclass = cpptype("std::vector< {T}, {Allocator} >", "<vector>") + tfactory = cpptype("std::vector< {T}, {Allocator} >", "<vector>") - vec_type = tclass(T=create_type("float32"), Allocator="std::allocator< float >") + vec_type = tfactory(T=create_type("float32"), Allocator="std::allocator< float >") assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >" assert deconstify(constify(vec_type)) == vec_type def test_cpptype_invalid_construction(): - tclass = cpptype("std::vector< {}, {Allocator} >", "<vector>") + tfactory = cpptype("std::vector< {}, {Allocator} >", "<vector>") with pytest.raises(IndexError): - _ = tclass(Allocator="SomeAlloc") + _ = tfactory(Allocator="SomeAlloc") with pytest.raises(KeyError): - _ = tclass("int") + _ = tfactory("int") with pytest.raises(ValueError, match="Too many positional arguments"): - _ = tclass("int", "bogus", Allocator="SomeAlloc") + _ = tfactory("int", "bogus", Allocator="SomeAlloc") with pytest.raises(ValueError, match="Extraneous keyword arguments"): - _ = tclass("int", Allocator="SomeAlloc", bogus=2) + _ = tfactory("int", Allocator="SomeAlloc", bogus=2) def test_cpptype_const(): - tclass = cpptype("std::vector< {T} >", "<vector>") + tfactory = cpptype("std::vector< {T} >", "<vector>") - vec_type = tclass(T=create_type("uint32")) - assert constify(vec_type) == tclass(T=create_type("uint32"), const=True) + vec_type = tfactory(T=create_type("uint32")) + assert constify(vec_type) == tfactory(T=create_type("uint32"), const=True) - vec_type = tclass(T=create_type("uint32"), const=True) - assert deconstify(vec_type) == tclass(T=create_type("uint32"), const=False) + vec_type = tfactory(T=create_type("uint32"), const=True) + assert deconstify(vec_type) == tfactory(T=create_type("uint32"), const=False) def test_cpptype_ref(): - tclass = cpptype("std::vector< {T} >", "<vector>") + tfactory = cpptype("std::vector< {T} >", "<vector>") - vec_type = tclass(T=create_type("uint32"), ref=True) + vec_type = tfactory(T=create_type("uint32"), ref=True) assert isinstance(vec_type, Ref) - assert strip_ptr_ref(vec_type) == tclass(T=create_type("uint32")) + assert strip_ptr_ref(vec_type) == tfactory(T=create_type("uint32")) def test_cpptype_inherits_headers(): - optional_tclass = cpptype("std::optional< {T} >", "<optional>") - vec_tclass = cpptype("std::vector< {T} >", "<vector>") - - vec_type = vec_tclass(T=optional_tclass(T="int")) - assert ( - vec_type.includes - == optional_tclass.class_includes | vec_tclass.class_includes - == {HeaderFile.parse("<optional>"), HeaderFile.parse("<vector>")} - ) + optional_tfactory = cpptype("std::optional< {T} >", "<optional>") + vec_tfactory = cpptype("std::vector< {T} >", "<vector>") + + vec_type = vec_tfactory(T=optional_tfactory(T="int")) + assert vec_type.includes == { + HeaderFile.parse("<optional>"), + HeaderFile.parse("<vector>"), + }