From d6eb671a178a4e5d2d1038bd636b1832694a9e71 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 3 Apr 2024 12:06:38 +0200 Subject: [PATCH] refactor types hashing and equality --- src/pystencils/types/basic_types.py | 193 ++++++++++++++-------------- tests/nbackend/types/test_types.py | 7 +- 2 files changed, 101 insertions(+), 99 deletions(-) diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py index 3678ea126..ebfe9d610 100644 --- a/src/pystencils/types/basic_types.py +++ b/src/pystencils/types/basic_types.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import final, TypeVar, Any, Sequence +from typing import final, TypeVar, Any, Sequence, cast from dataclasses import dataclass from copy import copy @@ -46,11 +46,22 @@ class PsType(ABC): return None # ------------------------------------------------------------------------------------------- - # Internal virtual operations + # Internal operations # ------------------------------------------------------------------------------------------- - def _base_equal(self, other: PsType) -> bool: - return type(self) is type(other) and self._const == other._const + @abstractmethod + def __args__(self) -> tuple[Any, ...]: + """Arguments to this type. + + The tuple returned by this method is used to serialize, deserialize, and check equality of types. + For each instantiable subclass ``MyType`` of ``PsType``, the following must hold: + + ``` + t = MyType(< arguments >) + assert MyType(*t.__args__()) == t + ``` + """ + pass def _const_string(self) -> str: return "const " if self._const else "" @@ -63,16 +74,21 @@ class PsType(ABC): # Dunder Methods # ------------------------------------------------------------------------------------------- - @abstractmethod def __eq__(self, other: object) -> bool: - pass + if self is other: + return True + + if type(self) is not type(other): + return False + + other = cast(PsType, other) + return self.__args__() == other.__args__() def __str__(self) -> str: return self.c_string() - @abstractmethod def __hash__(self) -> int: - pass + return hash((type(self), self.__args__())) class PsCustomType(PsType): @@ -92,13 +108,13 @@ class PsCustomType(PsType): def name(self) -> str: return self._name - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsCustomType): - return False - return self._base_equal(other) and self._name == other._name - - def __hash__(self) -> int: - return hash(("PsCustomType", self._name, self._const)) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsCustomType("std::vector< int >") + >>> t == PsCustomType(*t.__args__()) + True + """ + return (self._name,) def c_string(self) -> str: return f"{self._const_string()} {self._name}" @@ -142,18 +158,18 @@ class PsPointerType(PsDereferencableType): super().__init__(base_type, const) self._restrict = restrict + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsPointerType(PsBoolType(), const=True) + >>> t == PsPointerType(*t.__args__()) + True + """ + return (self._base_type, self._const, self._restrict) + @property def restrict(self) -> bool: return self._restrict - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsPointerType): - return False - return self._base_equal(other) and self._base_type == other._base_type - - def __hash__(self) -> int: - return hash(("PsPointerType", self._base_type, self._restrict, self._const)) - def c_string(self) -> str: base_str = self._base_type.c_string() restrict_str = " RESTRICT" if self._restrict else "" @@ -172,6 +188,14 @@ class PsArrayType(PsDereferencableType): self._length = length super().__init__(base_type, const) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsArrayType(PsBoolType(), 13, const=True) + >>> t == PsArrayType(*t.__args__()) + True + """ + return (self._base_type, self._length, self._const) + @property def length(self) -> int | None: return self._length @@ -179,19 +203,6 @@ class PsArrayType(PsDereferencableType): def c_string(self) -> str: return f"{self._base_type.c_string()} [{str(self._length) if self._length is not None else ''}]" - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsArrayType): - return False - - return ( - self._base_equal(other) - and self._base_type == other._base_type - and self._length == other._length - ) - - def __hash__(self) -> int: - return hash(("PsArrayType", self._base_type, self._length, self._const)) - def __repr__(self) -> str: return f"PsArrayType(element_type={repr(self._base_type)}, size={self._length}, const={self._const})" @@ -229,6 +240,14 @@ class PsStructType(PsType): raise ValueError(f"Duplicate struct member name: {member.name}") names.add(member.name) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsStructType([("idx", PsSignedIntegerType(32)), ("val", PsBoolType())], "sname") + >>> t == PsStructType(*t.__args__()) + True + """ + return (self._members, self._name, self._const) + @property def members(self) -> tuple[PsStructType.Member, ...]: return self._members @@ -276,19 +295,6 @@ class PsStructType(PsType): else: return self._name - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsStructType): - return False - - return ( - self._base_equal(other) - and self._name == other._name - and self._members == other._members - ) - - def __hash__(self) -> int: - return hash(("PsStructTupe", self._name, self._members, self._const)) - def __repr__(self) -> str: members = ", ".join(f"{m.dtype} {m.name}" for m in self._members) name = "<anonymous>" if self.anonymous else f"name={self._name}" @@ -386,6 +392,14 @@ class PsVectorType(PsNumericType): self._vector_entries = vector_entries self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsVectorType(PsBoolType(), 8, True) + >>> t == PsVectorType(*t.__args__()) + True + """ + return (self._scalar_type, self._vector_entries, self._const) + @property def scalar_type(self) -> PsScalarType: return self._scalar_type @@ -437,21 +451,6 @@ class PsVectorType(PsNumericType): [element] * self._vector_entries, dtype=self.scalar_type.numpy_dtype ) - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsVectorType): - return False - - return ( - self._base_equal(other) - and self._scalar_type == other._scalar_type - and self._vector_entries == other._vector_entries - ) - - def __hash__(self) -> int: - return hash( - ("PsVectorType", self._scalar_type, self._vector_entries, self._const) - ) - def c_string(self) -> str: raise PsTypeError("Cannot retrieve C type string for generic vector types.") @@ -473,6 +472,14 @@ class PsBoolType(PsScalarType): def __init__(self, const: bool = False): super().__init__(const) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsBoolType(True) + >>> t == PsBoolType(*t.__args__()) + True + """ + return (self._const,) + @property def width(self) -> int: return 8 @@ -506,16 +513,7 @@ class PsBoolType(PsScalarType): def c_string(self) -> str: return "bool" - - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsBoolType): - return False - - return self._base_equal(other) - - def __hash__(self) -> int: - return hash(("PsBoolType", self._const)) - + class PsIntegerType(PsScalarType, ABC): """Signed and unsigned integer types. @@ -574,20 +572,7 @@ class PsIntegerType(PsScalarType, ABC): return np_type(value) raise PsTypeError(f"Could not interpret {value} as {repr(self)}") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsIntegerType): - return False - - return ( - self._base_equal(other) - and self._width == other._width - and self._signed == other._signed - ) - - def __hash__(self) -> int: - return hash(("PsIntegerType", self._width, self._signed, self._const)) - + def c_string(self) -> str: prefix = "" if self._signed else "u" return f"{self._const_string()}{prefix}int{self._width}_t" @@ -612,6 +597,14 @@ class PsSignedIntegerType(PsIntegerType): def __init__(self, width: int, const: bool = False): super().__init__(width, True, const) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsSignedIntegerType(32, True) + >>> t == PsSignedIntegerType(*t.__args__()) + True + """ + return (self._width, self._const) + @final class PsUnsignedIntegerType(PsIntegerType): @@ -629,6 +622,14 @@ class PsUnsignedIntegerType(PsIntegerType): def __init__(self, width: int, const: bool = False): super().__init__(width, False, const) + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsUnsignedIntegerType(32, True) + >>> t == PsUnsignedIntegerType(*t.__args__()) + True + """ + return (self._width, self._const) + @final class PsIeeeFloatType(PsScalarType): @@ -653,6 +654,14 @@ class PsIeeeFloatType(PsScalarType): super().__init__(const) self._width = width + def __args__(self) -> tuple[Any, ...]: + """ + >>> t = PsIeeeFloatType(32, True) + >>> t == PsIeeeFloatType(*t.__args__()) + True + """ + return (self._width, self._const) + @property def width(self) -> int: return self._width @@ -698,14 +707,6 @@ class PsIeeeFloatType(PsScalarType): raise PsTypeError(f"Could not interpret {value} as {repr(self)}") - def __eq__(self, other: object) -> bool: - if not isinstance(other, PsIeeeFloatType): - return False - return self._base_equal(other) and self._width == other._width - - def __hash__(self) -> int: - return hash(("PsIeeeFloatType", self._width, self._const)) - def c_string(self) -> str: match self._width: case 16: diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py index 204ee24cf..24a46ab90 100644 --- a/tests/nbackend/types/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -22,9 +22,10 @@ def test_parsing_positive(): assert create_type("const uint32_t * restrict") == Ptr( UInt(32, const=True), restrict=True ) - assert create_type("float * * const") == Ptr(Ptr(Fp(32)), const=True) - assert create_type("uint16 * const") == Ptr(UInt(16), const=True) - assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True) + assert create_type("float * * const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=False) + assert create_type("float * * restrict const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=True) + assert create_type("uint16 * const") == Ptr(UInt(16), const=True, restrict=False) + assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True, restrict=False) def test_parsing_negative(): -- GitLab