diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index c973097eb87f6c3c0c202562fd7709464806ccf1..e6fc4bb78ddfe18a5ac572700ec7d59d97fd84cf 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import final, Any, Sequence +from typing import final, Any, Sequence, SupportsIndex from dataclasses import dataclass import numpy as np @@ -105,8 +105,14 @@ class PsArrayType(PsDereferencableType): """ def __init__( - self, element_type: PsType, shape: Sequence[int], const: bool = False + self, element_type: PsType, shape: SupportsIndex | Sequence[SupportsIndex], const: bool = False ): + from operator import index + if isinstance(shape, SupportsIndex): + shape = (index(shape),) + else: + shape = tuple(index(s) for s in shape) + if not shape or any(s <= 0 for s in shape): raise ValueError(f"Invalid array shape: {shape}") @@ -115,7 +121,7 @@ class PsArrayType(PsDereferencableType): element_type = deconstify(element_type) - self._shape = tuple(shape) + self._shape = shape super().__init__(element_type, const) def __args__(self) -> tuple[Any, ...]: diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py index 200db7a2ec4ec38857bc2dfb84a75d23b64c385a..165d572de5d191e759e5d8a6bea06c0f71884374 100644 --- a/tests/nbackend/types/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -152,12 +152,14 @@ def test_struct_types(): def test_array_types(): - t = PsArrayType(UInt(64), [42]) + t = PsArrayType(UInt(64), 42) assert t.dim == 1 assert t.shape == (42,) assert not t.const assert t.c_string() == "uint64_t[42]" + assert t == PsArrayType(UInt(64), (42,)) + t = PsArrayType(UInt(64), [3, 4, 5]) assert t.dim == 3 assert t.shape == (3, 4, 5)