Skip to content
Snippets Groups Projects
arrays.py 5.77 KiB
from __future__ import annotations

from typing import Sequence
from types import EllipsisType

from abc import ABC

from .constants import PsConstant
from ..types import (
    PsType,
    PsPointerType,
    PsIntegerType,
    PsUnsignedIntegerType,
)

from .symbols import PsSymbol
from ..defaults import DEFAULTS


class PsLinearizedArray:
    """Class to model N-dimensional contiguous arrays.

    **Memory Layout, Shape and Strides**

    The memory layout of an array is defined by its shape and strides.
    Both shape and stride entries may either be constants or special variables associated with
    exactly one array.

    Shape and strides may be specified at construction in the following way.
    For constant entries, their value must be given as an integer.
    For variable shape entries and strides, the Ellipsis `...` must be passed instead.
    Internally, the passed ``index_dtype`` will be used to create typed constants (`PsConstant`)
    and variables (`PsArrayShapeSymbol` and `PsArrayStrideSymbol`) from the passed values.

    Args:
        name: The array's name; also used to derive the names of its base pointer
            (``<name>_data``) and of its shape/stride symbols.
        element_type: Type of the array's elements.
        shape: One entry per dimension: an `int` for a fixed extent, or ``...``
            for an extent that is only known at runtime.
        strides: One entry per dimension, specified like ``shape``.
        index_dtype: Integer type given to every shape and stride constant and symbol.

    Raises:
        ValueError: If ``shape`` and ``strides`` have different lengths.
    """

    def __init__(
        self,
        name: str,
        element_type: PsType,
        shape: Sequence[int | EllipsisType],
        strides: Sequence[int | EllipsisType],
        index_dtype: PsIntegerType = DEFAULTS.index_dtype,
    ):
        self._name = name
        self._element_type = element_type
        self._index_dtype = index_dtype

        if len(shape) != len(strides):
            raise ValueError("Shape and stride tuples must have the same length")

        self._shape: tuple[PsArrayShapeSymbol | PsConstant, ...] = self._make_entries(
            shape, PsArrayShapeSymbol, index_dtype
        )
        self._strides: tuple[PsArrayStrideSymbol | PsConstant, ...] = (
            self._make_entries(strides, PsArrayStrideSymbol, index_dtype)
        )

        self._base_ptr = PsArrayBasePointer(f"{self._name}_data", self)

    def _make_entries(self, spec, symbol_cls, index_dtype: PsIntegerType) -> tuple:
        """Turn a shape/stride spec into typed constants and array-bound symbols."""
        # `...` is a singleton, so compare by identity. `==` would defer to the
        # entry's own `__eq__`, which is neither needed nor guaranteed to behave.
        return tuple(
            (
                symbol_cls(self, i, index_dtype)
                if s is ...
                else PsConstant(s, index_dtype)
            )
            for i, s in enumerate(spec)
        )

    @property
    def name(self) -> str:
        """The array's name"""
        return self._name

    @property
    def base_pointer(self) -> PsArrayBasePointer:
        """The array's base pointer"""
        return self._base_ptr

    @property
    def shape(self) -> tuple[PsArrayShapeSymbol | PsConstant, ...]:
        """The array's shape, expressed using `PsConstant` and `PsArrayShapeSymbol`"""
        return self._shape

    @property
    def shape_spec(self) -> tuple[EllipsisType | int, ...]:
        """The array's shape, expressed using `int` and `...`"""
        return tuple(
            (s.value if isinstance(s, PsConstant) else ...) for s in self._shape
        )

    @property
    def strides(self) -> tuple[PsArrayStrideSymbol | PsConstant, ...]:
        """The array's strides, expressed using `PsConstant` and `PsArrayStrideSymbol`"""
        return self._strides

    @property
    def strides_spec(self) -> tuple[EllipsisType | int, ...]:
        """The array's strides, expressed using `int` and `...`"""
        return tuple(
            (s.value if isinstance(s, PsConstant) else ...) for s in self._strides
        )

    @property
    def element_type(self) -> PsType:
        """The type of the array's elements"""
        return self._element_type

    def __repr__(self) -> str:
        return (
            f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])"
        )


class PsArrayAssocSymbol(PsSymbol, ABC):
    """Base class for symbols that belong to one particular array.

    The base pointer as well as the per-coordinate shape and stride symbols
    of a `PsLinearizedArray` derive from this class; each instance keeps a
    reference to the array it is bound to.
    """

    __match_args__ = ("name", "dtype", "array")

    def __init__(self, name: str, dtype: PsType, array: PsLinearizedArray):
        # Let `PsSymbol` handle name and type first, then record the owning array.
        super(PsArrayAssocSymbol, self).__init__(name, dtype)
        self._array: PsLinearizedArray = array

    @property
    def array(self) -> PsLinearizedArray:
        """The array this symbol is bound to."""
        owner = self._array
        return owner


class PsArrayBasePointer(PsArrayAssocSymbol):
    """Symbol for an array's base pointer.

    The pointer is typed as a pointer to the array's element type.
    """

    def __init__(self, name: str, array: PsLinearizedArray):
        dtype = PsPointerType(array.element_type)
        # `PsArrayAssocSymbol.__init__` already stores ``array`` as ``self._array``,
        # so no further bookkeeping is needed here.
        super().__init__(name, dtype, array)


class TypeErasedBasePointer(PsArrayBasePointer):
    """Base pointer for arrays whose element type has been erased.

    Used primarily for arrays of anonymous structs. Instead of pointing to the
    array's element type, the pointer is typed as a pointer to
    ``PsUnsignedIntegerType(8)`` (presumably a byte-sized unsigned integer).
    """

    def __init__(self, name: str, array: PsLinearizedArray):
        dtype = PsPointerType(PsUnsignedIntegerType(8))
        # Deliberately skip `PsArrayBasePointer.__init__`: it would derive the
        # pointer type from ``array.element_type``, which is exactly what is erased
        # here. Initialization continues with the next class in the MRO after
        # `PsArrayBasePointer` (i.e. `PsArrayAssocSymbol`), which also stores
        # ``array`` as ``self._array``.
        super(PsArrayBasePointer, self).__init__(name, dtype, array)


class PsArrayShapeSymbol(PsArrayAssocSymbol):
    """Symbol standing for an array's extent along a single coordinate.

    Instances are created by `PsLinearizedArray` itself; do not construct them
    directly, but obtain them from `PsLinearizedArray.shape`.
    """

    __match_args__ = (*PsArrayAssocSymbol.__match_args__, "coordinate")

    def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
        super().__init__(f"_size_{array.name}_{coordinate}", dtype, array)
        self._coordinate = coordinate

    @property
    def coordinate(self) -> int:
        """Index of the dimension whose extent this symbol represents."""
        dim = self._coordinate
        return dim


class PsArrayStrideSymbol(PsArrayAssocSymbol):
    """Symbol standing for an array's stride along a single coordinate.

    Instances are created by `PsLinearizedArray` itself; do not construct them
    directly, but obtain them from `PsLinearizedArray.strides`.
    """

    __match_args__ = (*PsArrayAssocSymbol.__match_args__, "coordinate")

    def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
        super().__init__(f"_stride_{array.name}_{coordinate}", dtype, array)
        self._coordinate = coordinate

    @property
    def coordinate(self) -> int:
        """Index of the dimension whose stride this symbol represents."""
        dim = self._coordinate
        return dim