Skip to content
Snippets Groups Projects
Commit 7d4f72c5 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

simplify type class factory as a function

parent 80f224a1
No related branches found
No related tags found
1 merge request!12Improve versatility and robustness of `cpptype`, and document it in the user guide
Pipeline #71534 failed
from __future__ import annotations from __future__ import annotations
from typing import Any, Iterable, Sequence, Mapping from typing import Any, Iterable, Sequence, Mapping, Callable
from abc import ABC from abc import ABC
from dataclasses import dataclass from dataclasses import dataclass
from itertools import chain from itertools import chain
...@@ -107,7 +107,7 @@ class CppType(PsCustomType, ABC): ...@@ -107,7 +107,7 @@ class CppType(PsCustomType, ABC):
def cpptype( def cpptype(
typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] = () typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
) -> type[CppType]: ) -> Callable[..., CppType | Ref]:
headers: list[str | HeaderFile] headers: list[str | HeaderFile]
if isinstance(include, (str, HeaderFile)): if isinstance(include, (str, HeaderFile)):
...@@ -121,15 +121,14 @@ def cpptype( ...@@ -121,15 +121,14 @@ def cpptype(
template_string = typestr template_string = typestr
class_includes = frozenset(HeaderFile.parse(h) for h in headers) class_includes = frozenset(HeaderFile.parse(h) for h in headers)
class TypeClassFactory(TypeClass): def factory(*args, ref: bool = False, **kwargs):
def __new__(cls, *args, ref: bool = False, **kwargs): obj = TypeClass(*args, **kwargs)
obj = TypeClass(*args, **kwargs) if ref:
if ref: return Ref(obj)
return Ref(obj) else:
else: return obj
return obj
return TypeClassFactory return staticmethod(factory)
class Ref(PsType): class Ref(PsType):
......
...@@ -6,13 +6,12 @@ from pystencils.types import constify, deconstify ...@@ -6,13 +6,12 @@ from pystencils.types import constify, deconstify
def test_cpptypes(): def test_cpptypes():
tclass = cpptype("std::vector< {}, {} >", "<vector>") tfactory = cpptype("std::vector< {}, {} >", "<vector>")
vec_type = tclass(create_type("float32"), "std::allocator< float >") vec_type = tfactory(create_type("float32"), "std::allocator< float >")
assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >" assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >"
assert ( assert (
tclass.class_includes vec_type.includes
== vec_type.includes
== {HeaderFile("vector", system_header=True)} == {HeaderFile("vector", system_header=True)}
) )
...@@ -20,60 +19,59 @@ def test_cpptypes(): ...@@ -20,60 +19,59 @@ def test_cpptypes():
assert deconstify(constify(vec_type)) == vec_type assert deconstify(constify(vec_type)) == vec_type
# Duplicate Equality # Duplicate Equality
assert tclass(create_type("float32"), "std::allocator< float >") == vec_type assert tfactory(create_type("float32"), "std::allocator< float >") == vec_type
# Not equal with different argument even though it produces the same string # Not equal with different argument even though it produces the same string
assert tclass("float", "std::allocator< float >") != vec_type assert tfactory("float", "std::allocator< float >") != vec_type
# The same with keyword arguments # The same with keyword arguments
tclass = cpptype("std::vector< {T}, {Allocator} >", "<vector>") tfactory = cpptype("std::vector< {T}, {Allocator} >", "<vector>")
vec_type = tclass(T=create_type("float32"), Allocator="std::allocator< float >") vec_type = tfactory(T=create_type("float32"), Allocator="std::allocator< float >")
assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >" assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >"
assert deconstify(constify(vec_type)) == vec_type assert deconstify(constify(vec_type)) == vec_type
def test_cpptype_invalid_construction(): def test_cpptype_invalid_construction():
tclass = cpptype("std::vector< {}, {Allocator} >", "<vector>") tfactory = cpptype("std::vector< {}, {Allocator} >", "<vector>")
with pytest.raises(IndexError): with pytest.raises(IndexError):
_ = tclass(Allocator="SomeAlloc") _ = tfactory(Allocator="SomeAlloc")
with pytest.raises(KeyError): with pytest.raises(KeyError):
_ = tclass("int") _ = tfactory("int")
with pytest.raises(ValueError, match="Too many positional arguments"): with pytest.raises(ValueError, match="Too many positional arguments"):
_ = tclass("int", "bogus", Allocator="SomeAlloc") _ = tfactory("int", "bogus", Allocator="SomeAlloc")
with pytest.raises(ValueError, match="Extraneous keyword arguments"): with pytest.raises(ValueError, match="Extraneous keyword arguments"):
_ = tclass("int", Allocator="SomeAlloc", bogus=2) _ = tfactory("int", Allocator="SomeAlloc", bogus=2)
def test_cpptype_const(): def test_cpptype_const():
tclass = cpptype("std::vector< {T} >", "<vector>") tfactory = cpptype("std::vector< {T} >", "<vector>")
vec_type = tclass(T=create_type("uint32")) vec_type = tfactory(T=create_type("uint32"))
assert constify(vec_type) == tclass(T=create_type("uint32"), const=True) assert constify(vec_type) == tfactory(T=create_type("uint32"), const=True)
vec_type = tclass(T=create_type("uint32"), const=True) vec_type = tfactory(T=create_type("uint32"), const=True)
assert deconstify(vec_type) == tclass(T=create_type("uint32"), const=False) assert deconstify(vec_type) == tfactory(T=create_type("uint32"), const=False)
def test_cpptype_ref(): def test_cpptype_ref():
tclass = cpptype("std::vector< {T} >", "<vector>") tfactory = cpptype("std::vector< {T} >", "<vector>")
vec_type = tclass(T=create_type("uint32"), ref=True) vec_type = tfactory(T=create_type("uint32"), ref=True)
assert isinstance(vec_type, Ref) assert isinstance(vec_type, Ref)
assert strip_ptr_ref(vec_type) == tclass(T=create_type("uint32")) assert strip_ptr_ref(vec_type) == tfactory(T=create_type("uint32"))
def test_cpptype_inherits_headers(): def test_cpptype_inherits_headers():
optional_tclass = cpptype("std::optional< {T} >", "<optional>") optional_tfactory = cpptype("std::optional< {T} >", "<optional>")
vec_tclass = cpptype("std::vector< {T} >", "<vector>") vec_tfactory = cpptype("std::vector< {T} >", "<vector>")
vec_type = vec_tclass(T=optional_tclass(T="int")) vec_type = vec_tfactory(T=optional_tfactory(T="int"))
assert ( assert vec_type.includes == {
vec_type.includes HeaderFile.parse("<optional>"),
== optional_tclass.class_includes | vec_tclass.class_includes HeaderFile.parse("<vector>"),
== {HeaderFile.parse("<optional>"), HeaderFile.parse("<vector>")} }
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment