Skip to content
Snippets Groups Projects
Commit 95a195ea authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Add scalar shapes to array type; generalize to SupportsIndex

parent c3ed7ca5
No related branches found
No related tags found
1 merge request!420Revised Array Modelling & Memory Model
Pipeline #69604 passed
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import final, Any, Sequence from typing import final, Any, Sequence, SupportsIndex
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
...@@ -105,8 +105,14 @@ class PsArrayType(PsDereferencableType): ...@@ -105,8 +105,14 @@ class PsArrayType(PsDereferencableType):
""" """
def __init__( def __init__(
self, element_type: PsType, shape: Sequence[int], const: bool = False self, element_type: PsType, shape: SupportsIndex | Sequence[SupportsIndex], const: bool = False
): ):
from operator import index
if isinstance(shape, SupportsIndex):
shape = (index(shape),)
else:
shape = tuple(index(s) for s in shape)
if not shape or any(s <= 0 for s in shape): if not shape or any(s <= 0 for s in shape):
raise ValueError(f"Invalid array shape: {shape}") raise ValueError(f"Invalid array shape: {shape}")
...@@ -115,7 +121,7 @@ class PsArrayType(PsDereferencableType): ...@@ -115,7 +121,7 @@ class PsArrayType(PsDereferencableType):
element_type = deconstify(element_type) element_type = deconstify(element_type)
self._shape = tuple(shape) self._shape = shape
super().__init__(element_type, const) super().__init__(element_type, const)
def __args__(self) -> tuple[Any, ...]: def __args__(self) -> tuple[Any, ...]:
......
...@@ -152,12 +152,14 @@ def test_struct_types(): ...@@ -152,12 +152,14 @@ def test_struct_types():
def test_array_types(): def test_array_types():
t = PsArrayType(UInt(64), [42]) t = PsArrayType(UInt(64), 42)
assert t.dim == 1 assert t.dim == 1
assert t.shape == (42,) assert t.shape == (42,)
assert not t.const assert not t.const
assert t.c_string() == "uint64_t[42]" assert t.c_string() == "uint64_t[42]"
assert t == PsArrayType(UInt(64), (42,))
t = PsArrayType(UInt(64), [3, 4, 5]) t = PsArrayType(UInt(64), [3, 4, 5])
assert t.dim == 3 assert t.dim == 3
assert t.shape == (3, 4, 5) assert t.shape == (3, 4, 5)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment