From 95a195eae441cb23db336b044a852eaf05543981 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 15 Oct 2024 16:38:20 +0200
Subject: [PATCH] Add scalar shapes to array type; generalize to SupportsIndex

---
 src/pystencils/types/types.py      | 12 +++++++++---
 tests/nbackend/types/test_types.py |  4 +++-
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py
index c973097eb..e6fc4bb78 100644
--- a/src/pystencils/types/types.py
+++ b/src/pystencils/types/types.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import final, Any, Sequence
+from typing import final, Any, Sequence, SupportsIndex
 from dataclasses import dataclass
 
 import numpy as np
@@ -105,8 +105,14 @@ class PsArrayType(PsDereferencableType):
     """
 
     def __init__(
-        self, element_type: PsType, shape: Sequence[int], const: bool = False
+        self, element_type: PsType, shape: SupportsIndex | Sequence[SupportsIndex], const: bool = False
     ):
+        from operator import index
+        if isinstance(shape, SupportsIndex):
+            shape = (index(shape),)
+        else:
+            shape = tuple(index(s) for s in shape)
+
         if not shape or any(s <= 0 for s in shape):
             raise ValueError(f"Invalid array shape: {shape}")
         
@@ -115,7 +121,7 @@ class PsArrayType(PsDereferencableType):
         
         element_type = deconstify(element_type)
 
-        self._shape = tuple(shape)
+        self._shape = shape
         super().__init__(element_type, const)
 
     def __args__(self) -> tuple[Any, ...]:
diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py
index 200db7a2e..165d572de 100644
--- a/tests/nbackend/types/test_types.py
+++ b/tests/nbackend/types/test_types.py
@@ -152,12 +152,14 @@ def test_struct_types():
 
 
 def test_array_types():
-    t = PsArrayType(UInt(64), [42])
+    t = PsArrayType(UInt(64), 42)
     assert t.dim == 1
     assert t.shape == (42,)
     assert not t.const
     assert t.c_string() == "uint64_t[42]"
 
+    assert t == PsArrayType(UInt(64), (42,))
+
     t = PsArrayType(UInt(64), [3, 4, 5])
     assert t.dim == 3
     assert t.shape == (3, 4, 5)
-- 
GitLab