Skip to content
Snippets Groups Projects
Commit d6eb671a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

refactor types hashing and equality

parent 191cc207
2 merge requests!379Type System Refactor,!374Uniqueness of Data Type Instances
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import final, TypeVar, Any, Sequence
from typing import final, TypeVar, Any, Sequence, cast
from dataclasses import dataclass
from copy import copy
......@@ -46,11 +46,22 @@ class PsType(ABC):
return None
# -------------------------------------------------------------------------------------------
# Internal virtual operations
# Internal operations
# -------------------------------------------------------------------------------------------
def _base_equal(self, other: PsType) -> bool:
return type(self) is type(other) and self._const == other._const
@abstractmethod
def __args__(self) -> tuple[Any, ...]:
"""Arguments to this type.
The tuple returned by this method is used to serialize, deserialize, and check equality of types.
For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:
```
t = MyType(< arguments >)
assert MyType(*t.__args__()) == t
```
"""
pass
def _const_string(self) -> str:
return "const " if self._const else ""
......@@ -63,16 +74,21 @@ class PsType(ABC):
# Dunder Methods
# -------------------------------------------------------------------------------------------
@abstractmethod
def __eq__(self, other: object) -> bool:
pass
if self is other:
return True
if type(self) is not type(other):
return False
other = cast(PsType, other)
return self.__args__() == other.__args__()
def __str__(self) -> str:
return self.c_string()
@abstractmethod
def __hash__(self) -> int:
pass
return hash((type(self), self.__args__()))
class PsCustomType(PsType):
......@@ -92,13 +108,13 @@ class PsCustomType(PsType):
def name(self) -> str:
return self._name
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsCustomType):
return False
return self._base_equal(other) and self._name == other._name
def __hash__(self) -> int:
return hash(("PsCustomType", self._name, self._const))
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsCustomType("std::vector< int >")
>>> t == PsCustomType(*t.__args__())
True
"""
return (self._name,)
def c_string(self) -> str:
return f"{self._const_string()} {self._name}"
......@@ -142,18 +158,18 @@ class PsPointerType(PsDereferencableType):
super().__init__(base_type, const)
self._restrict = restrict
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsPointerType(PsBoolType(), const=True)
>>> t == PsPointerType(*t.__args__())
True
"""
return (self._base_type, self._const, self._restrict)
@property
def restrict(self) -> bool:
return self._restrict
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsPointerType):
return False
return self._base_equal(other) and self._base_type == other._base_type
def __hash__(self) -> int:
return hash(("PsPointerType", self._base_type, self._restrict, self._const))
def c_string(self) -> str:
base_str = self._base_type.c_string()
restrict_str = " RESTRICT" if self._restrict else ""
......@@ -172,6 +188,14 @@ class PsArrayType(PsDereferencableType):
self._length = length
super().__init__(base_type, const)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsArrayType(PsBoolType(), 13, const=True)
>>> t == PsArrayType(*t.__args__())
True
"""
return (self._base_type, self._length, self._const)
@property
def length(self) -> int | None:
return self._length
......@@ -179,19 +203,6 @@ class PsArrayType(PsDereferencableType):
def c_string(self) -> str:
return f"{self._base_type.c_string()} [{str(self._length) if self._length is not None else ''}]"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsArrayType):
return False
return (
self._base_equal(other)
and self._base_type == other._base_type
and self._length == other._length
)
def __hash__(self) -> int:
return hash(("PsArrayType", self._base_type, self._length, self._const))
def __repr__(self) -> str:
return f"PsArrayType(element_type={repr(self._base_type)}, size={self._length}, const={self._const})"
......@@ -229,6 +240,14 @@ class PsStructType(PsType):
raise ValueError(f"Duplicate struct member name: {member.name}")
names.add(member.name)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsStructType([("idx", PsSignedIntegerType(32)), ("val", PsBoolType())], "sname")
>>> t == PsStructType(*t.__args__())
True
"""
return (self._members, self._name, self._const)
@property
def members(self) -> tuple[PsStructType.Member, ...]:
return self._members
......@@ -276,19 +295,6 @@ class PsStructType(PsType):
else:
return self._name
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsStructType):
return False
return (
self._base_equal(other)
and self._name == other._name
and self._members == other._members
)
def __hash__(self) -> int:
return hash(("PsStructTupe", self._name, self._members, self._const))
def __repr__(self) -> str:
members = ", ".join(f"{m.dtype} {m.name}" for m in self._members)
name = "<anonymous>" if self.anonymous else f"name={self._name}"
......@@ -386,6 +392,14 @@ class PsVectorType(PsNumericType):
self._vector_entries = vector_entries
self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsVectorType(PsBoolType(), 8, True)
>>> t == PsVectorType(*t.__args__())
True
"""
return (self._scalar_type, self._vector_entries, self._const)
@property
def scalar_type(self) -> PsScalarType:
return self._scalar_type
......@@ -437,21 +451,6 @@ class PsVectorType(PsNumericType):
[element] * self._vector_entries, dtype=self.scalar_type.numpy_dtype
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsVectorType):
return False
return (
self._base_equal(other)
and self._scalar_type == other._scalar_type
and self._vector_entries == other._vector_entries
)
def __hash__(self) -> int:
return hash(
("PsVectorType", self._scalar_type, self._vector_entries, self._const)
)
def c_string(self) -> str:
raise PsTypeError("Cannot retrieve C type string for generic vector types.")
......@@ -473,6 +472,14 @@ class PsBoolType(PsScalarType):
def __init__(self, const: bool = False):
super().__init__(const)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsBoolType(True)
>>> t == PsBoolType(*t.__args__())
True
"""
return (self._const,)
@property
def width(self) -> int:
return 8
......@@ -506,16 +513,7 @@ class PsBoolType(PsScalarType):
def c_string(self) -> str:
return "bool"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsBoolType):
return False
return self._base_equal(other)
def __hash__(self) -> int:
return hash(("PsBoolType", self._const))
class PsIntegerType(PsScalarType, ABC):
"""Signed and unsigned integer types.
......@@ -574,20 +572,7 @@ class PsIntegerType(PsScalarType, ABC):
return np_type(value)
raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsIntegerType):
return False
return (
self._base_equal(other)
and self._width == other._width
and self._signed == other._signed
)
def __hash__(self) -> int:
return hash(("PsIntegerType", self._width, self._signed, self._const))
def c_string(self) -> str:
prefix = "" if self._signed else "u"
return f"{self._const_string()}{prefix}int{self._width}_t"
......@@ -612,6 +597,14 @@ class PsSignedIntegerType(PsIntegerType):
def __init__(self, width: int, const: bool = False):
super().__init__(width, True, const)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsSignedIntegerType(32, True)
>>> t == PsSignedIntegerType(*t.__args__())
True
"""
return (self._width, self._const)
@final
class PsUnsignedIntegerType(PsIntegerType):
......@@ -629,6 +622,14 @@ class PsUnsignedIntegerType(PsIntegerType):
def __init__(self, width: int, const: bool = False):
super().__init__(width, False, const)
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsUnsignedIntegerType(32, True)
>>> t == PsUnsignedIntegerType(*t.__args__())
True
"""
return (self._width, self._const)
@final
class PsIeeeFloatType(PsScalarType):
......@@ -653,6 +654,14 @@ class PsIeeeFloatType(PsScalarType):
super().__init__(const)
self._width = width
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsIeeeFloatType(32, True)
>>> t == PsIeeeFloatType(*t.__args__())
True
"""
return (self._width, self._const)
@property
def width(self) -> int:
return self._width
......@@ -698,14 +707,6 @@ class PsIeeeFloatType(PsScalarType):
raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsIeeeFloatType):
return False
return self._base_equal(other) and self._width == other._width
def __hash__(self) -> int:
return hash(("PsIeeeFloatType", self._width, self._const))
def c_string(self) -> str:
match self._width:
case 16:
......
......@@ -22,9 +22,10 @@ def test_parsing_positive():
assert create_type("const uint32_t * restrict") == Ptr(
UInt(32, const=True), restrict=True
)
assert create_type("float * * const") == Ptr(Ptr(Fp(32)), const=True)
assert create_type("uint16 * const") == Ptr(UInt(16), const=True)
assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True)
assert create_type("float * * const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=False)
assert create_type("float * * restrict const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=True)
assert create_type("uint16 * const") == Ptr(UInt(16), const=True, restrict=False)
assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True, restrict=False)
def test_parsing_negative():
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment