Skip to content
Snippets Groups Projects
Commit d6eb671a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

refactor types hashing and equality

parent 191cc207
No related branches found
No related tags found
2 merge requests!379Type System Refactor,!374Uniqueness of Data Type Instances
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import final, TypeVar, Any, Sequence from typing import final, TypeVar, Any, Sequence, cast
from dataclasses import dataclass from dataclasses import dataclass
from copy import copy from copy import copy
...@@ -46,11 +46,22 @@ class PsType(ABC): ...@@ -46,11 +46,22 @@ class PsType(ABC):
return None return None
# ------------------------------------------------------------------------------------------- # -------------------------------------------------------------------------------------------
# Internal virtual operations # Internal operations
# ------------------------------------------------------------------------------------------- # -------------------------------------------------------------------------------------------
def _base_equal(self, other: PsType) -> bool: @abstractmethod
return type(self) is type(other) and self._const == other._const def __args__(self) -> tuple[Any, ...]:
"""Arguments to this type.
The tuple returned by this method is used to serialize, deserialize, and check equality of types.
For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:
```
t = MyType(< arguments >)
assert MyType(*t.__args__()) == t
```
"""
pass
def _const_string(self) -> str: def _const_string(self) -> str:
return "const " if self._const else "" return "const " if self._const else ""
...@@ -63,16 +74,21 @@ class PsType(ABC): ...@@ -63,16 +74,21 @@ class PsType(ABC):
# Dunder Methods # Dunder Methods
# ------------------------------------------------------------------------------------------- # -------------------------------------------------------------------------------------------
@abstractmethod
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
pass if self is other:
return True
if type(self) is not type(other):
return False
other = cast(PsType, other)
return self.__args__() == other.__args__()
def __str__(self) -> str: def __str__(self) -> str:
return self.c_string() return self.c_string()
@abstractmethod
def __hash__(self) -> int: def __hash__(self) -> int:
pass return hash((type(self), self.__args__()))
class PsCustomType(PsType): class PsCustomType(PsType):
...@@ -92,13 +108,13 @@ class PsCustomType(PsType): ...@@ -92,13 +108,13 @@ class PsCustomType(PsType):
def name(self) -> str: def name(self) -> str:
return self._name return self._name
def __eq__(self, other: object) -> bool: def __args__(self) -> tuple[Any, ...]:
if not isinstance(other, PsCustomType): """
return False >>> t = PsCustomType("std::vector< int >")
return self._base_equal(other) and self._name == other._name >>> t == PsCustomType(*t.__args__())
True
def __hash__(self) -> int: """
return hash(("PsCustomType", self._name, self._const)) return (self._name,)
def c_string(self) -> str: def c_string(self) -> str:
return f"{self._const_string()} {self._name}" return f"{self._const_string()} {self._name}"
...@@ -142,18 +158,18 @@ class PsPointerType(PsDereferencableType): ...@@ -142,18 +158,18 @@ class PsPointerType(PsDereferencableType):
super().__init__(base_type, const) super().__init__(base_type, const)
self._restrict = restrict self._restrict = restrict
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsPointerType(PsBoolType(), const=True)
>>> t == PsPointerType(*t.__args__())
True
"""
return (self._base_type, self._const, self._restrict)
@property @property
def restrict(self) -> bool: def restrict(self) -> bool:
return self._restrict return self._restrict
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsPointerType):
return False
return self._base_equal(other) and self._base_type == other._base_type
def __hash__(self) -> int:
return hash(("PsPointerType", self._base_type, self._restrict, self._const))
def c_string(self) -> str: def c_string(self) -> str:
base_str = self._base_type.c_string() base_str = self._base_type.c_string()
restrict_str = " RESTRICT" if self._restrict else "" restrict_str = " RESTRICT" if self._restrict else ""
...@@ -172,6 +188,14 @@ class PsArrayType(PsDereferencableType): ...@@ -172,6 +188,14 @@ class PsArrayType(PsDereferencableType):
self._length = length self._length = length
super().__init__(base_type, const) super().__init__(base_type, const)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsArrayType(PsBoolType(), 13, const=True)
>>> t == PsArrayType(*t.__args__())
True
"""
return (self._base_type, self._length, self._const)
@property @property
def length(self) -> int | None: def length(self) -> int | None:
return self._length return self._length
...@@ -179,19 +203,6 @@ class PsArrayType(PsDereferencableType): ...@@ -179,19 +203,6 @@ class PsArrayType(PsDereferencableType):
def c_string(self) -> str: def c_string(self) -> str:
return f"{self._base_type.c_string()} [{str(self._length) if self._length is not None else ''}]" return f"{self._base_type.c_string()} [{str(self._length) if self._length is not None else ''}]"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsArrayType):
return False
return (
self._base_equal(other)
and self._base_type == other._base_type
and self._length == other._length
)
def __hash__(self) -> int:
return hash(("PsArrayType", self._base_type, self._length, self._const))
def __repr__(self) -> str: def __repr__(self) -> str:
return f"PsArrayType(element_type={repr(self._base_type)}, size={self._length}, const={self._const})" return f"PsArrayType(element_type={repr(self._base_type)}, size={self._length}, const={self._const})"
...@@ -229,6 +240,14 @@ class PsStructType(PsType): ...@@ -229,6 +240,14 @@ class PsStructType(PsType):
raise ValueError(f"Duplicate struct member name: {member.name}") raise ValueError(f"Duplicate struct member name: {member.name}")
names.add(member.name) names.add(member.name)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsStructType([("idx", PsSignedIntegerType(32)), ("val", PsBoolType())], "sname")
>>> t == PsStructType(*t.__args__())
True
"""
return (self._members, self._name, self._const)
@property @property
def members(self) -> tuple[PsStructType.Member, ...]: def members(self) -> tuple[PsStructType.Member, ...]:
return self._members return self._members
...@@ -276,19 +295,6 @@ class PsStructType(PsType): ...@@ -276,19 +295,6 @@ class PsStructType(PsType):
else: else:
return self._name return self._name
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsStructType):
return False
return (
self._base_equal(other)
and self._name == other._name
and self._members == other._members
)
def __hash__(self) -> int:
return hash(("PsStructTupe", self._name, self._members, self._const))
def __repr__(self) -> str: def __repr__(self) -> str:
members = ", ".join(f"{m.dtype} {m.name}" for m in self._members) members = ", ".join(f"{m.dtype} {m.name}" for m in self._members)
name = "<anonymous>" if self.anonymous else f"name={self._name}" name = "<anonymous>" if self.anonymous else f"name={self._name}"
...@@ -386,6 +392,14 @@ class PsVectorType(PsNumericType): ...@@ -386,6 +392,14 @@ class PsVectorType(PsNumericType):
self._vector_entries = vector_entries self._vector_entries = vector_entries
self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type) self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsVectorType(PsBoolType(), 8, True)
>>> t == PsVectorType(*t.__args__())
True
"""
return (self._scalar_type, self._vector_entries, self._const)
@property @property
def scalar_type(self) -> PsScalarType: def scalar_type(self) -> PsScalarType:
return self._scalar_type return self._scalar_type
...@@ -437,21 +451,6 @@ class PsVectorType(PsNumericType): ...@@ -437,21 +451,6 @@ class PsVectorType(PsNumericType):
[element] * self._vector_entries, dtype=self.scalar_type.numpy_dtype [element] * self._vector_entries, dtype=self.scalar_type.numpy_dtype
) )
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsVectorType):
return False
return (
self._base_equal(other)
and self._scalar_type == other._scalar_type
and self._vector_entries == other._vector_entries
)
def __hash__(self) -> int:
return hash(
("PsVectorType", self._scalar_type, self._vector_entries, self._const)
)
def c_string(self) -> str: def c_string(self) -> str:
raise PsTypeError("Cannot retrieve C type string for generic vector types.") raise PsTypeError("Cannot retrieve C type string for generic vector types.")
...@@ -473,6 +472,14 @@ class PsBoolType(PsScalarType): ...@@ -473,6 +472,14 @@ class PsBoolType(PsScalarType):
def __init__(self, const: bool = False): def __init__(self, const: bool = False):
super().__init__(const) super().__init__(const)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsBoolType(True)
>>> t == PsBoolType(*t.__args__())
True
"""
return (self._const,)
@property @property
def width(self) -> int: def width(self) -> int:
return 8 return 8
...@@ -506,16 +513,7 @@ class PsBoolType(PsScalarType): ...@@ -506,16 +513,7 @@ class PsBoolType(PsScalarType):
def c_string(self) -> str: def c_string(self) -> str:
return "bool" return "bool"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsBoolType):
return False
return self._base_equal(other)
def __hash__(self) -> int:
return hash(("PsBoolType", self._const))
class PsIntegerType(PsScalarType, ABC): class PsIntegerType(PsScalarType, ABC):
"""Signed and unsigned integer types. """Signed and unsigned integer types.
...@@ -574,20 +572,7 @@ class PsIntegerType(PsScalarType, ABC): ...@@ -574,20 +572,7 @@ class PsIntegerType(PsScalarType, ABC):
return np_type(value) return np_type(value)
raise PsTypeError(f"Could not interpret {value} as {repr(self)}") raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsIntegerType):
return False
return (
self._base_equal(other)
and self._width == other._width
and self._signed == other._signed
)
def __hash__(self) -> int:
return hash(("PsIntegerType", self._width, self._signed, self._const))
def c_string(self) -> str: def c_string(self) -> str:
prefix = "" if self._signed else "u" prefix = "" if self._signed else "u"
return f"{self._const_string()}{prefix}int{self._width}_t" return f"{self._const_string()}{prefix}int{self._width}_t"
...@@ -612,6 +597,14 @@ class PsSignedIntegerType(PsIntegerType): ...@@ -612,6 +597,14 @@ class PsSignedIntegerType(PsIntegerType):
def __init__(self, width: int, const: bool = False): def __init__(self, width: int, const: bool = False):
super().__init__(width, True, const) super().__init__(width, True, const)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsSignedIntegerType(32, True)
>>> t == PsSignedIntegerType(*t.__args__())
True
"""
return (self._width, self._const)
@final @final
class PsUnsignedIntegerType(PsIntegerType): class PsUnsignedIntegerType(PsIntegerType):
...@@ -629,6 +622,14 @@ class PsUnsignedIntegerType(PsIntegerType): ...@@ -629,6 +622,14 @@ class PsUnsignedIntegerType(PsIntegerType):
def __init__(self, width: int, const: bool = False): def __init__(self, width: int, const: bool = False):
super().__init__(width, False, const) super().__init__(width, False, const)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsUnsignedIntegerType(32, True)
>>> t == PsUnsignedIntegerType(*t.__args__())
True
"""
return (self._width, self._const)
@final @final
class PsIeeeFloatType(PsScalarType): class PsIeeeFloatType(PsScalarType):
...@@ -653,6 +654,14 @@ class PsIeeeFloatType(PsScalarType): ...@@ -653,6 +654,14 @@ class PsIeeeFloatType(PsScalarType):
super().__init__(const) super().__init__(const)
self._width = width self._width = width
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsIeeeFloatType(32, True)
>>> t == PsIeeeFloatType(*t.__args__())
True
"""
return (self._width, self._const)
@property @property
def width(self) -> int: def width(self) -> int:
return self._width return self._width
...@@ -698,14 +707,6 @@ class PsIeeeFloatType(PsScalarType): ...@@ -698,14 +707,6 @@ class PsIeeeFloatType(PsScalarType):
raise PsTypeError(f"Could not interpret {value} as {repr(self)}") raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsIeeeFloatType):
return False
return self._base_equal(other) and self._width == other._width
def __hash__(self) -> int:
return hash(("PsIeeeFloatType", self._width, self._const))
def c_string(self) -> str: def c_string(self) -> str:
match self._width: match self._width:
case 16: case 16:
......
...@@ -22,9 +22,10 @@ def test_parsing_positive(): ...@@ -22,9 +22,10 @@ def test_parsing_positive():
assert create_type("const uint32_t * restrict") == Ptr( assert create_type("const uint32_t * restrict") == Ptr(
UInt(32, const=True), restrict=True UInt(32, const=True), restrict=True
) )
assert create_type("float * * const") == Ptr(Ptr(Fp(32)), const=True) assert create_type("float * * const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=False)
assert create_type("uint16 * const") == Ptr(UInt(16), const=True) assert create_type("float * * restrict const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=True)
assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True) assert create_type("uint16 * const") == Ptr(UInt(16), const=True, restrict=False)
assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True, restrict=False)
def test_parsing_negative(): def test_parsing_negative():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment