From d6eb671a178a4e5d2d1038bd636b1832694a9e71 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 3 Apr 2024 12:06:38 +0200
Subject: [PATCH] refactor types hashing and equality

---
 src/pystencils/types/basic_types.py | 193 ++++++++++++++--------------
 tests/nbackend/types/test_types.py  |   7 +-
 2 files changed, 101 insertions(+), 99 deletions(-)

diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py
index 3678ea126..ebfe9d610 100644
--- a/src/pystencils/types/basic_types.py
+++ b/src/pystencils/types/basic_types.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import final, TypeVar, Any, Sequence
+from typing import final, TypeVar, Any, Sequence, cast
 from dataclasses import dataclass
 from copy import copy
 
@@ -46,11 +46,22 @@ class PsType(ABC):
         return None
 
     #   -------------------------------------------------------------------------------------------
-    #   Internal virtual operations
+    #   Internal operations
     #   -------------------------------------------------------------------------------------------
 
-    def _base_equal(self, other: PsType) -> bool:
-        return type(self) is type(other) and self._const == other._const
+    @abstractmethod
+    def __args__(self) -> tuple[Any, ...]:
+        """Arguments to this type.
+        
+        The tuple returned by this method is used to serialize, deserialize, and check equality of types.
+        For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:
+
+        ```
+        t = MyType(< arguments >)
+        assert MyType(*t.__args__()) == t
+        ```
+        """
+        pass
 
     def _const_string(self) -> str:
         return "const " if self._const else ""
@@ -63,16 +74,21 @@ class PsType(ABC):
     #   Dunder Methods
     #   -------------------------------------------------------------------------------------------
 
-    @abstractmethod
     def __eq__(self, other: object) -> bool:
-        pass
+        if self is other:
+            return True
+        
+        if type(self) is not type(other):
+            return False
+        
+        other = cast(PsType, other)
+        return self.__args__() == other.__args__()
 
     def __str__(self) -> str:
         return self.c_string()
 
-    @abstractmethod
     def __hash__(self) -> int:
-        pass
+        return hash((type(self), self.__args__()))
 
 
 class PsCustomType(PsType):
@@ -92,13 +108,13 @@ class PsCustomType(PsType):
     def name(self) -> str:
         return self._name
 
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsCustomType):
-            return False
-        return self._base_equal(other) and self._name == other._name
-
-    def __hash__(self) -> int:
-        return hash(("PsCustomType", self._name, self._const))
+    def __args__(self) -> tuple[Any, ...]:
+        """
+        >>> t = PsCustomType("std::vector< int >")
+        >>> t == PsCustomType(*t.__args__())
+        True
+        """
+        return (self._name,)
 
     def c_string(self) -> str:
         return f"{self._const_string()} {self._name}"
@@ -142,18 +158,18 @@ class PsPointerType(PsDereferencableType):
         super().__init__(base_type, const)
         self._restrict = restrict
 
+    def __args__(self) -> tuple[Any, ...]:
+        """
+        >>> t = PsPointerType(PsBoolType(), const=True)
+        >>> t == PsPointerType(*t.__args__())
+        True
+        """
+        return (self._base_type, self._const, self._restrict)
+
     @property
     def restrict(self) -> bool:
         return self._restrict
 
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsPointerType):
-            return False
-        return self._base_equal(other) and self._base_type == other._base_type
-
-    def __hash__(self) -> int:
-        return hash(("PsPointerType", self._base_type, self._restrict, self._const))
-
     def c_string(self) -> str:
         base_str = self._base_type.c_string()
         restrict_str = " RESTRICT" if self._restrict else ""
@@ -172,6 +188,14 @@ class PsArrayType(PsDereferencableType):
         self._length = length
         super().__init__(base_type, const)
 
+    def __args__(self) -> tuple[Any, ...]:
+        """
+        >>> t = PsArrayType(PsBoolType(), 13, const=True)
+        >>> t == PsArrayType(*t.__args__())
+        True
+        """
+        return (self._base_type, self._length, self._const)
+
     @property
     def length(self) -> int | None:
         return self._length
@@ -179,19 +203,6 @@ class PsArrayType(PsDereferencableType):
     def c_string(self) -> str:
         return f"{self._base_type.c_string()} [{str(self._length) if self._length is not None else ''}]"
 
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsArrayType):
-            return False
-
-        return (
-            self._base_equal(other)
-            and self._base_type == other._base_type
-            and self._length == other._length
-        )
-
-    def __hash__(self) -> int:
-        return hash(("PsArrayType", self._base_type, self._length, self._const))
-
     def __repr__(self) -> str:
         return f"PsArrayType(element_type={repr(self._base_type)}, size={self._length}, const={self._const})"
 
@@ -229,6 +240,14 @@ class PsStructType(PsType):
                 raise ValueError(f"Duplicate struct member name: {member.name}")
             names.add(member.name)
 
+    def __args__(self) -> tuple[Any, ...]:
+        """
+        >>> t = PsStructType([("idx", PsSignedIntegerType(32)), ("val", PsBoolType())], "sname")
+        >>> t == PsStructType(*t.__args__())
+        True
+        """
+        return (self._members, self._name, self._const)
+
     @property
     def members(self) -> tuple[PsStructType.Member, ...]:
         return self._members
@@ -276,19 +295,6 @@ class PsStructType(PsType):
         else:
             return self._name
 
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsStructType):
-            return False
-
-        return (
-            self._base_equal(other)
-            and self._name == other._name
-            and self._members == other._members
-        )
-
-    def __hash__(self) -> int:
-        return hash(("PsStructTupe", self._name, self._members, self._const))
-
     def __repr__(self) -> str:
         members = ", ".join(f"{m.dtype} {m.name}" for m in self._members)
         name = "<anonymous>" if self.anonymous else f"name={self._name}"
@@ -386,6 +392,14 @@ class PsVectorType(PsNumericType):
         self._vector_entries = vector_entries
         self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type)
 
+    def __args__(self) -> tuple[Any, ...]:
+        """
+        >>> t = PsVectorType(PsBoolType(), 8, True)
+        >>> t == PsVectorType(*t.__args__())
+        True
+        """
+        return (self._scalar_type, self._vector_entries, self._const)
+
     @property
     def scalar_type(self) -> PsScalarType:
         return self._scalar_type
@@ -437,21 +451,6 @@ class PsVectorType(PsNumericType):
             [element] * self._vector_entries, dtype=self.scalar_type.numpy_dtype
         )
 
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsVectorType):
-            return False
-
-        return (
-            self._base_equal(other)
-            and self._scalar_type == other._scalar_type
-            and self._vector_entries == other._vector_entries
-        )
-
-    def __hash__(self) -> int:
-        return hash(
-            ("PsVectorType", self._scalar_type, self._vector_entries, self._const)
-        )
-
     def c_string(self) -> str:
         raise PsTypeError("Cannot retrieve C type string for generic vector types.")
 
@@ -473,6 +472,14 @@ class PsBoolType(PsScalarType):
     def __init__(self, const: bool = False):
         super().__init__(const)
 
+    def __args__(self) -> tuple[Any, ...]:
+        """
+        >>> t = PsBoolType(True)
+        >>> t == PsBoolType(*t.__args__())
+        True
+        """
+        return (self._const,)
+
     @property
     def width(self) -> int:
         return 8
@@ -506,16 +513,7 @@ class PsBoolType(PsScalarType):
 
     def c_string(self) -> str:
         return "bool"
-
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsBoolType):
-            return False
-
-        return self._base_equal(other)
-
-    def __hash__(self) -> int:
-        return hash(("PsBoolType", self._const))
-
+    
 
 class PsIntegerType(PsScalarType, ABC):
     """Signed and unsigned integer types.
@@ -574,20 +572,7 @@ class PsIntegerType(PsScalarType, ABC):
             return np_type(value)
 
         raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
-
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsIntegerType):
-            return False
-
-        return (
-            self._base_equal(other)
-            and self._width == other._width
-            and self._signed == other._signed
-        )
-
-    def __hash__(self) -> int:
-        return hash(("PsIntegerType", self._width, self._signed, self._const))
-
+    
     def c_string(self) -> str:
         prefix = "" if self._signed else "u"
         return f"{self._const_string()}{prefix}int{self._width}_t"
@@ -612,6 +597,14 @@ class PsSignedIntegerType(PsIntegerType):
     def __init__(self, width: int, const: bool = False):
         super().__init__(width, True, const)
 
+    def __args__(self) -> tuple[Any, ...]:
+        """
+        >>> t = PsSignedIntegerType(32, True)
+        >>> t == PsSignedIntegerType(*t.__args__())
+        True
+        """
+        return (self._width, self._const)
+
 
 @final
 class PsUnsignedIntegerType(PsIntegerType):
@@ -629,6 +622,14 @@ class PsUnsignedIntegerType(PsIntegerType):
     def __init__(self, width: int, const: bool = False):
         super().__init__(width, False, const)
 
+    def __args__(self) -> tuple[Any, ...]:
+        """
+        >>> t = PsUnsignedIntegerType(32, True)
+        >>> t == PsUnsignedIntegerType(*t.__args__())
+        True
+        """
+        return (self._width, self._const)
+
 
 @final
 class PsIeeeFloatType(PsScalarType):
@@ -653,6 +654,14 @@ class PsIeeeFloatType(PsScalarType):
         super().__init__(const)
         self._width = width
 
+    def __args__(self) -> tuple[Any, ...]:
+        """
+        >>> t = PsIeeeFloatType(32, True)
+        >>> t == PsIeeeFloatType(*t.__args__())
+        True
+        """
+        return (self._width, self._const)
+
     @property
     def width(self) -> int:
         return self._width
@@ -698,14 +707,6 @@ class PsIeeeFloatType(PsScalarType):
 
         raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
 
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsIeeeFloatType):
-            return False
-        return self._base_equal(other) and self._width == other._width
-
-    def __hash__(self) -> int:
-        return hash(("PsIeeeFloatType", self._width, self._const))
-
     def c_string(self) -> str:
         match self._width:
             case 16:
diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py
index 204ee24cf..24a46ab90 100644
--- a/tests/nbackend/types/test_types.py
+++ b/tests/nbackend/types/test_types.py
@@ -22,9 +22,10 @@ def test_parsing_positive():
     assert create_type("const uint32_t * restrict") == Ptr(
         UInt(32, const=True), restrict=True
     )
-    assert create_type("float * * const") == Ptr(Ptr(Fp(32)), const=True)
-    assert create_type("uint16 * const") == Ptr(UInt(16), const=True)
-    assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True)
+    assert create_type("float * * const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=False)
+    assert create_type("float * * restrict const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=True)
+    assert create_type("uint16 * const") == Ptr(UInt(16), const=True, restrict=False)
+    assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True, restrict=False)
 
 
 def test_parsing_negative():
-- 
GitLab