Skip to content
Snippets Groups Projects
Commit ee64e7e1 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

array associated symbols cleanup

parent 0b14e421
No related branches found
No related tags found
No related merge requests found
Pipeline #63215 failed
...@@ -5,7 +5,7 @@ An array has a fixed name, dimensionality, and element type, as well as a number ...@@ -5,7 +5,7 @@ An array has a fixed name, dimensionality, and element type, as well as a number
variables. variables.
The associated variables are the *shape* and *strides* of the array, modelled by the The associated variables are the *shape* and *strides* of the array, modelled by the
`PsArrayShapeVar` and `PsArrayStrideVar` classes. They have integer type and are used to `PsArrayShapeSymbol` and `PsArrayStrideSymbol` classes. They have integer type and are used to
reason about the array's memory layout. reason about the array's memory layout.
...@@ -68,7 +68,7 @@ class PsLinearizedArray: ...@@ -68,7 +68,7 @@ class PsLinearizedArray:
For constant entries, their value must be given as an integer. For constant entries, their value must be given as an integer.
For variable shape entries and strides, the Ellipsis `...` must be passed instead. For variable shape entries and strides, the Ellipsis `...` must be passed instead.
Internally, the passed ``index_dtype`` will be used to create typed constants (`PsTypedConstant`) Internally, the passed ``index_dtype`` will be used to create typed constants (`PsTypedConstant`)
and variables (`PsArrayShapeVar` and `PsArrayStrideVar`) from the passed values. and variables (`PsArrayShapeSymbol` and `PsArrayStrideSymbol`) from the passed values.
""" """
def __init__( def __init__(
...@@ -86,18 +86,18 @@ class PsLinearizedArray: ...@@ -86,18 +86,18 @@ class PsLinearizedArray:
if len(shape) != len(strides): if len(shape) != len(strides):
raise ValueError("Shape and stride tuples must have the same length") raise ValueError("Shape and stride tuples must have the same length")
self._shape: tuple[PsArrayShapeVar | PsConstant, ...] = tuple( self._shape: tuple[PsArrayShapeSymbol | PsConstant, ...] = tuple(
( (
PsArrayShapeVar(self, i, index_dtype) PsArrayShapeSymbol(self, i, index_dtype)
if s == Ellipsis if s == Ellipsis
else PsConstant(s, index_dtype) else PsConstant(s, index_dtype)
) )
for i, s in enumerate(shape) for i, s in enumerate(shape)
) )
self._strides: tuple[PsArrayStrideVar | PsConstant, ...] = tuple( self._strides: tuple[PsArrayStrideSymbol | PsConstant, ...] = tuple(
( (
PsArrayStrideVar(self, i, index_dtype) PsArrayStrideSymbol(self, i, index_dtype)
if s == Ellipsis if s == Ellipsis
else PsConstant(s, index_dtype) else PsConstant(s, index_dtype)
) )
...@@ -117,8 +117,8 @@ class PsLinearizedArray: ...@@ -117,8 +117,8 @@ class PsLinearizedArray:
return self._base_ptr return self._base_ptr
@property @property
def shape(self) -> tuple[PsArrayShapeVar | PsConstant, ...]: def shape(self) -> tuple[PsArrayShapeSymbol | PsConstant, ...]:
"""The array's shape, expressed using `PsTypedConstant` and `PsArrayShapeVar`""" """The array's shape, expressed using `PsTypedConstant` and `PsArrayShapeSymbol`"""
return self._shape return self._shape
@property @property
...@@ -129,8 +129,8 @@ class PsLinearizedArray: ...@@ -129,8 +129,8 @@ class PsLinearizedArray:
) )
@property @property
def strides(self) -> tuple[PsArrayStrideVar | PsConstant, ...]: def strides(self) -> tuple[PsArrayStrideSymbol | PsConstant, ...]:
"""The array's strides, expressed using `PsTypedConstant` and `PsArrayStrideVar`""" """The array's strides, expressed using `PsTypedConstant` and `PsArrayStrideSymbol`"""
return self._strides return self._strides
@property @property
...@@ -144,32 +144,6 @@ class PsLinearizedArray: ...@@ -144,32 +144,6 @@ class PsLinearizedArray:
def element_type(self): def element_type(self):
return self._element_type return self._element_type
def _hashable_contents(self):
"""Contents by which to compare two instances of `PsLinearizedArray`.
Since equality checks on shape and stride variables internally check equality of their associated arrays,
if these variables would occur in here, an infinite recursion would follow.
Hence they are filtered and replaced by the ellipsis.
"""
shape_clean = self.shape_spec
strides_clean = self.strides_spec
return (
self._name,
self._element_type,
self._index_dtype,
shape_clean,
strides_clean,
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsLinearizedArray):
return False
return self._hashable_contents() == other._hashable_contents()
def __hash__(self) -> int:
return hash(self._hashable_contents())
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])" f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])"
...@@ -182,24 +156,18 @@ class PsArrayAssocSymbol(PsSymbol, ABC): ...@@ -182,24 +156,18 @@ class PsArrayAssocSymbol(PsSymbol, ABC):
Instances of this class represent pointers and indexing information bound Instances of this class represent pointers and indexing information bound
to a particular array. to a particular array.
""" """
init_arg_names: tuple[str, ...] = ("name", "dtype", "array")
__match_args__ = ("name", "dtype", "array") __match_args__ = ("name", "dtype", "array")
def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray): def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray):
super().__init__(name, dtype) super().__init__(name, dtype)
self._array = array self._array = array
def __getinitargs__(self):
return self.name, self.dtype, self.array
@property @property
def array(self) -> PsLinearizedArray: def array(self) -> PsLinearizedArray:
return self._array return self._array
class PsArrayBasePointer(PsArrayAssocSymbol): class PsArrayBasePointer(PsArrayAssocSymbol):
init_arg_names: tuple[str, ...] = ("name", "array")
__match_args__ = ("name", "array") __match_args__ = ("name", "array")
def __init__(self, name: str, array: PsLinearizedArray): def __init__(self, name: str, array: PsLinearizedArray):
...@@ -208,9 +176,6 @@ class PsArrayBasePointer(PsArrayAssocSymbol): ...@@ -208,9 +176,6 @@ class PsArrayBasePointer(PsArrayAssocSymbol):
self._array = array self._array = array
def __getinitargs__(self):
return self.name, self.array
class TypeErasedBasePointer(PsArrayBasePointer): class TypeErasedBasePointer(PsArrayBasePointer):
"""Base pointer for arrays whose element type has been erased. """Base pointer for arrays whose element type has been erased.
...@@ -224,14 +189,13 @@ class TypeErasedBasePointer(PsArrayBasePointer): ...@@ -224,14 +189,13 @@ class TypeErasedBasePointer(PsArrayBasePointer):
self._array = array self._array = array
class PsArrayShapeVar(PsArrayAssocSymbol): class PsArrayShapeSymbol(PsArrayAssocSymbol):
"""Variable that represents an array's shape in one coordinate. """Variable that represents an array's shape in one coordinate.
Do not instantiate this class yourself, but only use its instances Do not instantiate this class yourself, but only use its instances
as provided by `PsLinearizedArray.shape`. as provided by `PsLinearizedArray.shape`.
""" """
init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
__match_args__ = ("array", "coordinate", "dtype") __match_args__ = ("array", "coordinate", "dtype")
def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
...@@ -243,18 +207,13 @@ class PsArrayShapeVar(PsArrayAssocSymbol): ...@@ -243,18 +207,13 @@ class PsArrayShapeVar(PsArrayAssocSymbol):
def coordinate(self) -> int: def coordinate(self) -> int:
return self._coordinate return self._coordinate
def __getinitargs__(self):
return self.array, self.coordinate, self.dtype
class PsArrayStrideSymbol(PsArrayAssocSymbol):
class PsArrayStrideVar(PsArrayAssocSymbol):
"""Variable that represents an array's stride in one coordinate. """Variable that represents an array's stride in one coordinate.
Do not instantiate this class yourself, but only use its instances Do not instantiate this class yourself, but only use its instances
as provided by `PsLinearizedArray.strides`. as provided by `PsLinearizedArray.strides`.
""" """
init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
__match_args__ = ("array", "coordinate", "dtype") __match_args__ = ("array", "coordinate", "dtype")
def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
...@@ -265,6 +224,3 @@ class PsArrayStrideVar(PsArrayAssocSymbol): ...@@ -265,6 +224,3 @@ class PsArrayStrideVar(PsArrayAssocSymbol):
@property @property
def coordinate(self) -> int: def coordinate(self) -> int:
return self._coordinate return self._coordinate
def __getinitargs__(self):
return self.array, self.coordinate, self.dtype
...@@ -17,8 +17,8 @@ from ..arrays import ( ...@@ -17,8 +17,8 @@ from ..arrays import (
PsLinearizedArray, PsLinearizedArray,
PsArrayAssocSymbol, PsArrayAssocSymbol,
PsArrayBasePointer, PsArrayBasePointer,
PsArrayShapeVar, PsArrayShapeSymbol,
PsArrayStrideVar, PsArrayStrideSymbol,
) )
from ..types import ( from ..types import (
PsAbstractType, PsAbstractType,
...@@ -290,12 +290,12 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ ...@@ -290,12 +290,12 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
match variable: match variable:
case PsArrayBasePointer(): case PsArrayBasePointer():
code = f"{variable.dtype} {variable.name} = ({variable.dtype}) {buffer}.buf;" code = f"{variable.dtype} {variable.name} = ({variable.dtype}) {buffer}.buf;"
case PsArrayShapeVar(): case PsArrayShapeSymbol():
coord = variable.coordinate coord = variable.coordinate
code = ( code = (
f"{variable.dtype} {variable.name} = {buffer}.shape[{coord}];" f"{variable.dtype} {variable.name} = {buffer}.shape[{coord}];"
) )
case PsArrayStrideVar(): case PsArrayStrideSymbol():
coord = variable.coordinate coord = variable.coordinate
code = ( code = (
f"{variable.dtype} {variable.name} = " f"{variable.dtype} {variable.name} = "
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment