From 4bbfb3b65bb91ca9e697ec1abfde84ffa70b446e Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 21 Oct 2024 08:48:31 +0200
Subject: [PATCH] Revised Array Modelling & Memory Model

---
 .../source/api/symbolic_language/astnodes.rst |   2 +-
 docs/source/backend/objects.rst               |  41 ++-
 src/pystencils/backend/ast/expressions.py     | 161 ++++++++---
 src/pystencils/backend/ast/util.py            |  49 +++-
 src/pystencils/backend/emission.py            |  51 ++--
 .../backend/kernelcreation/ast_factory.py     |   6 +-
 .../backend/kernelcreation/freeze.py          |  35 ++-
 .../backend/kernelcreation/typification.py    | 253 +++++++++---------
 src/pystencils/backend/literals.py            |   2 +-
 src/pystencils/backend/platforms/sycl.py      |   4 +-
 src/pystencils/backend/platforms/x86.py       |   6 +-
 .../erase_anonymous_structs.py                |   7 +-
 .../hoist_loop_invariant_decls.py             |  10 +-
 src/pystencils/sympyextensions/pointers.py    |  11 +
 src/pystencils/types/types.py                 |  54 +++-
 tests/nbackend/kernelcreation/test_freeze.py  | 101 +++++++
 .../kernelcreation/test_typification.py       | 240 +++++++++++++----
 tests/nbackend/test_ast.py                    |  11 +-
 tests/nbackend/test_code_printing.py          |  34 ++-
 tests/nbackend/test_extensions.py             |   8 +-
 .../transformations/test_canonical_clone.py   |   4 +-
 .../transformations/test_hoist_invariants.py  |   4 +-
 tests/nbackend/types/test_types.py            |  46 +++-
 23 files changed, 853 insertions(+), 287 deletions(-)

diff --git a/docs/source/api/symbolic_language/astnodes.rst b/docs/source/api/symbolic_language/astnodes.rst
index 4d5c4b89f..ff31c98ec 100644
--- a/docs/source/api/symbolic_language/astnodes.rst
+++ b/docs/source/api/symbolic_language/astnodes.rst
@@ -4,6 +4,6 @@ Kernel Structure
 
 .. automodule:: pystencils.sympyextensions.astnodes
 
-.. autoclass:: pystencils.sympyextensions.AssignmentCollection
+.. autoclass:: pystencils.AssignmentCollection
     :members:
 
diff --git a/docs/source/backend/objects.rst b/docs/source/backend/objects.rst
index b0c3af6db..1b36842b8 100644
--- a/docs/source/backend/objects.rst
+++ b/docs/source/backend/objects.rst
@@ -1,15 +1,44 @@
-*****************************
-Symbols, Constants and Arrays
-*****************************
+****************************
+Constants and Memory Objects
+****************************
+
+Memory Objects: Symbols and Field Arrays
+========================================
+
+The Memory Model
+----------------
+
+In order to reason about memory accesses, mutability, invariance, and aliasing, the *pystencils* backend uses
+a very simple memory model. There are three types of memory objects:
+
+- Symbols (`PsSymbol`), which act as registers for data storage within the scope of a kernel
+- Field arrays (`PsLinearizedArray`), which represent a contiguous block of memory the kernel has access to, and
+- the *unmanaged heap*, which is a global catch-all memory object which all pointers not belonging to a field
+  array point into.
+
+All of these objects are disjoint, and cannot alias each other.
+Each symbol exists in isolation,
+field arrays do not overlap,
+and raw pointers are assumed not to point into memory owned by a symbol or field array.
+Instead, all raw pointers point into unmanaged heap memory, and are assumed to *always* alias one another:
+Each change brought to unmanaged memory by one raw pointer is assumed to affect the memory pointed to by
+another raw pointer.
+
+Classes
+-------
 
 .. autoclass:: pystencils.backend.symbols.PsSymbol
     :members:
 
-.. autoclass:: pystencils.backend.constants.PsConstant
+.. automodule:: pystencils.backend.arrays
     :members:
 
-.. autoclass:: pystencils.backend.literals.PsLiteral
+
+Constants and Literals
+======================
+
+.. autoclass:: pystencils.backend.constants.PsConstant
     :members:
 
-.. automodule:: pystencils.backend.arrays
+.. autoclass:: pystencils.backend.literals.PsLiteral
     :members:
diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index 5f9c95d5d..151f86c6e 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -1,8 +1,12 @@
 from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from typing import Sequence, overload, Callable, Any, cast
 import operator
 
+import numpy as np
+from numpy.typing import NDArray
+
 from ..symbols import PsSymbol
 from ..constants import PsConstant
 from ..literals import PsLiteral
@@ -99,7 +103,8 @@ class PsExpression(PsAstNode, ABC):
 
 
 class PsLvalue(ABC):
-    """Mix-in for all expressions that may occur as an lvalue"""
+    """Mix-in for all expressions that may occur as an lvalue;
+    i.e. expressions that represent a memory location."""
 
 
 class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression):
@@ -189,48 +194,99 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression):
 
 
 class PsSubscript(PsLvalue, PsExpression):
-    __match_args__ = ("base", "index")
+    """N-dimensional subscript into an array."""
+
+    __match_args__ = ("array", "index")
 
-    def __init__(self, base: PsExpression, index: PsExpression):
+    def __init__(self, arr: PsExpression, index: Sequence[PsExpression]):
         super().__init__()
-        self._base = base
-        self._index = index
+        self._arr = arr
+
+        if not index:
+            raise ValueError("Subscript index cannot be empty.")
+
+        self._index = list(index)
 
     @property
-    def base(self) -> PsExpression:
-        return self._base
+    def array(self) -> PsExpression:
+        return self._arr
 
-    @base.setter
-    def base(self, expr: PsExpression):
-        self._base = expr
+    @array.setter
+    def array(self, expr: PsExpression):
+        self._arr = expr
 
     @property
-    def index(self) -> PsExpression:
+    def index(self) -> list[PsExpression]:
         return self._index
 
     @index.setter
-    def index(self, expr: PsExpression):
-        self._index = expr
+    def index(self, idx: Sequence[PsExpression]):
+        self._index = list(idx)
 
     def clone(self) -> PsSubscript:
-        return PsSubscript(self._base.clone(), self._index.clone())
+        return PsSubscript(self._arr.clone(), [i.clone() for i in self._index])
 
     def get_children(self) -> tuple[PsAstNode, ...]:
-        return (self._base, self._index)
+        return (self._arr,) + tuple(self._index)
+
+    def set_child(self, idx: int, c: PsAstNode):
+        idx = range(len(self._index) + 1)[idx]
+        match idx:
+            case 0:
+                self.array = failing_cast(PsExpression, c)
+            case _:
+                self.index[idx - 1] = failing_cast(PsExpression, c)
+
+    def __repr__(self) -> str:
+        idx = ", ".join(repr(i) for i in self._index)
+        return f"PsSubscript({self._arr}, ({idx}))"
+
+
+class PsMemAcc(PsLvalue, PsExpression):
+    """Pointer-based memory access with type-dependent offset."""
+
+    __match_args__ = ("pointer", "offset")
+
+    def __init__(self, ptr: PsExpression, offset: PsExpression):
+        super().__init__()
+        self._ptr = ptr
+        self._offset = offset
+
+    @property
+    def pointer(self) -> PsExpression:
+        return self._ptr
+
+    @pointer.setter
+    def pointer(self, expr: PsExpression):
+        self._ptr = expr
+
+    @property
+    def offset(self) -> PsExpression:
+        return self._offset
+
+    @offset.setter
+    def offset(self, expr: PsExpression):
+        self._offset = expr
+
+    def clone(self) -> PsMemAcc:
+        return PsMemAcc(self._ptr.clone(), self._offset.clone())
+
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return (self._ptr, self._offset)
 
     def set_child(self, idx: int, c: PsAstNode):
         idx = [0, 1][idx]
         match idx:
             case 0:
-                self.base = failing_cast(PsExpression, c)
+                self.pointer = failing_cast(PsExpression, c)
             case 1:
-                self.index = failing_cast(PsExpression, c)
+                self.offset = failing_cast(PsExpression, c)
 
     def __repr__(self) -> str:
-        return f"Subscript({self._base})[{self._index}]"
+        return f"PsMemAcc({repr(self._ptr)}, {repr(self._offset)})"
 
 
-class PsArrayAccess(PsSubscript):
+class PsArrayAccess(PsMemAcc):
     __match_args__ = ("base_ptr", "index")
 
     def __init__(self, base_ptr: PsArrayBasePointer, index: PsExpression):
@@ -243,11 +299,11 @@ class PsArrayAccess(PsSubscript):
         return self._base_ptr
 
     @property
-    def base(self) -> PsExpression:
-        return self._base
+    def pointer(self) -> PsExpression:
+        return self._ptr
 
-    @base.setter
-    def base(self, expr: PsExpression):
+    @pointer.setter
+    def pointer(self, expr: PsExpression):
         if not isinstance(expr, PsSymbolExpr) or not isinstance(
             expr.symbol, PsArrayBasePointer
         ):
@@ -256,17 +312,25 @@ class PsArrayAccess(PsSubscript):
             )
 
         self._base_ptr = expr.symbol
-        self._base = expr
+        self._ptr = expr
 
     @property
     def array(self) -> PsLinearizedArray:
         return self._base_ptr.array
+    
+    @property
+    def index(self) -> PsExpression:
+        return self._offset
+
+    @index.setter
+    def index(self, expr: PsExpression):
+        self._offset = expr
 
     def clone(self) -> PsArrayAccess:
-        return PsArrayAccess(self._base_ptr, self._index.clone())
+        return PsArrayAccess(self._base_ptr, self._offset.clone())
 
     def __repr__(self) -> str:
-        return f"ArrayAccess({repr(self._base_ptr)}, {repr(self._index)})"
+        return f"PsArrayAccess({repr(self._base_ptr)}, {repr(self._offset)})"
 
 
 class PsVectorArrayAccess(PsArrayAccess):
@@ -314,7 +378,7 @@ class PsVectorArrayAccess(PsArrayAccess):
     def clone(self) -> PsVectorArrayAccess:
         return PsVectorArrayAccess(
             self._base_ptr,
-            self._index.clone(),
+            self._offset.clone(),
             self.vector_entries,
             self._stride,
             self._alignment,
@@ -525,11 +589,16 @@ class PsNeg(PsUnOp, PsNumericOpTrait):
         return operator.neg
 
 
-class PsDeref(PsLvalue, PsUnOp):
-    pass
-
-
 class PsAddressOf(PsUnOp):
+    """Take the address of a memory location.
+    
+    .. DANGER::
+        Taking the address of a memory location owned by a symbol or field array
+        introduces an alias to that memory location.
+        As pystencils assumes its symbols and fields to never be aliased, this can
+        subtly change the semantics of a kernel. 
+        Use the address-of operator with utmost care.
+    """
     pass
 
 
@@ -740,32 +809,44 @@ class PsLt(PsRel):
 
 
 class PsArrayInitList(PsExpression):
+    """N-dimensional array initialization matrix."""
+
     __match_args__ = ("items",)
 
-    def __init__(self, items: Sequence[PsExpression]):
+    def __init__(self, items: Sequence[PsExpression | Sequence[PsExpression | Sequence[PsExpression]]]):
         super().__init__()
-        self._items = list(items)
+        self._items = np.array(items, dtype=np.object_)
 
     @property
-    def items(self) -> list[PsExpression]:
+    def items_grid(self) -> NDArray[np.object_]:
         return self._items
+    
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return self._items.shape
+    
+    @property
+    def items(self) -> tuple[PsExpression, ...]:
+        return tuple(self._items.flat)  # type: ignore
 
     def get_children(self) -> tuple[PsAstNode, ...]:
-        return tuple(self._items)
+        return tuple(self._items.flat)  # type: ignore
 
     def set_child(self, idx: int, c: PsAstNode):
-        self._items[idx] = failing_cast(PsExpression, c)
+        self._items.flat[idx] = failing_cast(PsExpression, c)
 
     def clone(self) -> PsExpression:
-        return PsArrayInitList([expr.clone() for expr in self._items])
+        return PsArrayInitList(  
+            np.array([expr.clone() for expr in self.children]).reshape(  # type: ignore
+                self._items.shape
+            )
+        )
 
     def __repr__(self) -> str:
         return f"PsArrayInitList({repr(self._items)})"
 
 
-def evaluate_expression(
-    expr: PsExpression, valuation: dict[str, Any]
-) -> Any:
+def evaluate_expression(expr: PsExpression, valuation: dict[str, Any]) -> Any:
     """Evaluate a pystencils backend expression tree with values assigned to symbols according to the given valuation.
 
     Only a subset of expression nodes can be processed by this evaluator.
diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py
index 0d3b78629..72aff0a01 100644
--- a/src/pystencils/backend/ast/util.py
+++ b/src/pystencils/backend/ast/util.py
@@ -1,8 +1,15 @@
 from __future__ import annotations
-from typing import Any, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING, cast
+
+from ..exceptions import PsInternalCompilerError
+from ..symbols import PsSymbol
+from ..arrays import PsLinearizedArray
+from ...types import PsDereferencableType
+
 
 if TYPE_CHECKING:
     from .astnode import PsAstNode
+    from .expressions import PsExpression
 
 
 def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any:
@@ -36,3 +43,43 @@ class AstEqWrapper:
         #   TODO: consider replacing this with smth. more performant
         #   TODO: Check that repr is implemented by all AST nodes
         return hash(repr(self._node))
+
+
+def determine_memory_object(
+    expr: PsExpression,
+) -> tuple[PsSymbol | PsLinearizedArray | None, bool]:
+    """Return the memory object accessed by the given expression, together with its constness
+
+    Returns:
+        Tuple ``(mem_obj, const)`` identifying the memory object accessed by the given expression,
+        as well as its constness
+    """
+    from pystencils.backend.ast.expressions import (
+        PsSubscript,
+        PsLookup,
+        PsSymbolExpr,
+        PsMemAcc,
+        PsArrayAccess,
+    )
+
+    while isinstance(expr, (PsSubscript, PsLookup)):
+        match expr:
+            case PsSubscript(arr, _):
+                expr = arr
+            case PsLookup(record, _):
+                expr = record
+
+    match expr:
+        case PsSymbolExpr(symb):
+            return symb, symb.get_dtype().const
+        case PsMemAcc(ptr, _):
+            return None, cast(PsDereferencableType, ptr.get_dtype()).base_type.const
+        case PsArrayAccess(ptr, _):
+            return (
+                expr.array,
+                cast(PsDereferencableType, ptr.get_dtype()).base_type.const,
+            )
+        case _:
+            raise PsInternalCompilerError(
+                "The given expression is a transient and does not refer to a memory object"
+            )
diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py
index 8928cc689..579d47648 100644
--- a/src/pystencils/backend/emission.py
+++ b/src/pystencils/backend/emission.py
@@ -16,6 +16,7 @@ from .ast.structural import (
 )
 
 from .ast.expressions import (
+    PsExpression,
     PsAdd,
     PsAddressOf,
     PsArrayInitList,
@@ -26,7 +27,7 @@ from .ast.expressions import (
     PsCall,
     PsCast,
     PsConstantExpr,
-    PsDeref,
+    PsMemAcc,
     PsDiv,
     PsRem,
     PsIntDiv,
@@ -36,7 +37,6 @@ from .ast.expressions import (
     PsNeg,
     PsRightShift,
     PsSub,
-    PsSubscript,
     PsSymbolExpr,
     PsLiteralExpr,
     PsVectorArrayAccess,
@@ -50,6 +50,7 @@ from .ast.expressions import (
     PsLt,
     PsGe,
     PsLe,
+    PsSubscript
 )
 
 from .extensions.foreign_ast import PsForeignExpression
@@ -270,16 +271,27 @@ class CAstPrinter:
             case PsVectorArrayAccess():
                 raise EmissionError("Cannot print vectorized array accesses")
 
-            case PsSubscript(base, index):
+            case PsMemAcc(base, offset):
                 pc.push_op(Ops.Subscript, LR.Left)
                 base_code = self.visit(base, pc)
                 pc.pop_op()
 
                 pc.push_op(Ops.Weakest, LR.Middle)
-                index_code = self.visit(index, pc)
+                index_code = self.visit(offset, pc)
                 pc.pop_op()
 
                 return pc.parenthesize(f"{base_code}[{index_code}]", Ops.Subscript)
+            
+            case PsSubscript(base, indices):
+                pc.push_op(Ops.Subscript, LR.Left)
+                base_code = self.visit(base, pc)
+                pc.pop_op()
+
+                pc.push_op(Ops.Weakest, LR.Middle)
+                indices_code = "".join("[" + self.visit(idx, pc) + "]" for idx in indices)
+                pc.pop_op()
+
+                return pc.parenthesize(base_code + indices_code, Ops.Subscript)
 
             case PsLookup(aggr, member_name):
                 pc.push_op(Ops.Lookup, LR.Left)
@@ -320,12 +332,12 @@ class CAstPrinter:
 
                 return pc.parenthesize(f"!{operand_code}", Ops.Not)
 
-            case PsDeref(operand):
-                pc.push_op(Ops.Deref, LR.Right)
-                operand_code = self.visit(operand, pc)
-                pc.pop_op()
+            # case PsDeref(operand):
+            #     pc.push_op(Ops.Deref, LR.Right)
+            #     operand_code = self.visit(operand, pc)
+            #     pc.pop_op()
 
-                return pc.parenthesize(f"*{operand_code}", Ops.Deref)
+            #     return pc.parenthesize(f"*{operand_code}", Ops.Deref)
 
             case PsAddressOf(operand):
                 pc.push_op(Ops.AddressOf, LR.Right)
@@ -355,11 +367,19 @@ class CAstPrinter:
                     f"{cond_code} ? {then_code} : {else_code}", Ops.Ternary
                 )
 
-            case PsArrayInitList(items):
+            case PsArrayInitList(_):
+                def print_arr(item) -> str:
+                    if isinstance(item, PsExpression):
+                        return self.visit(item, pc)
+                    else:
+                        #   it's a subarray
+                        entries = ", ".join(print_arr(i) for i in item)
+                        return "{ " + entries + " }"
+
                 pc.push_op(Ops.Weakest, LR.Middle)
-                items_str = ", ".join(self.visit(item, pc) for item in items)
+                arr_str = print_arr(node.items_grid)
                 pc.pop_op()
-                return "{ " + items_str + " }"
+                return arr_str
 
             case PsForeignExpression(children):
                 pc.push_op(Ops.Weakest, LR.Middle)
@@ -379,10 +399,11 @@ class CAstPrinter:
     def _symbol_decl(self, symb: PsSymbol):
         dtype = symb.get_dtype()
 
-        array_dims = []
-        while isinstance(dtype, PsArrayType):
-            array_dims.append(dtype.length)
+        if isinstance(dtype, PsArrayType):
+            array_dims = dtype.shape 
             dtype = dtype.base_type
+        else:
+            array_dims = ()
 
         code = f"{dtype.c_string()} {symb.name}"
         for d in array_dims:
diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py
index a0328a123..0dc60b1b1 100644
--- a/src/pystencils/backend/kernelcreation/ast_factory.py
+++ b/src/pystencils/backend/kernelcreation/ast_factory.py
@@ -12,7 +12,7 @@ from ..symbols import PsSymbol
 from ..constants import PsConstant
 
 from .context import KernelCreationContext
-from .freeze import FreezeExpressions
+from .freeze import FreezeExpressions, ExprLike
 from .typification import Typifier
 from .iteration_space import FullIterationSpace
 
@@ -42,14 +42,14 @@ class AstFactory:
         pass
 
     @overload
-    def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression:
+    def parse_sympy(self, sp_obj: ExprLike) -> PsExpression:
         pass
 
     @overload
     def parse_sympy(self, sp_obj: AssignmentBase) -> PsAssignment:
         pass
 
-    def parse_sympy(self, sp_obj: sp.Expr | AssignmentBase) -> PsAstNode:
+    def parse_sympy(self, sp_obj: ExprLike | AssignmentBase) -> PsAstNode:
         """Parse a SymPy expression or assignment through `FreezeExpressions` and `Typifier`.
 
         The expression or assignment will be typified in a numerical context, using the kernel
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index ff7754ac2..0ae5a0d1b 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -14,7 +14,7 @@ from ...sympyextensions import (
     ConditionalFieldAccess,
 )
 from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType
-from ...sympyextensions.pointers import AddressOf
+from ...sympyextensions.pointers import AddressOf, mem_acc
 from ...field import Field, FieldType
 
 from .context import KernelCreationContext
@@ -55,6 +55,7 @@ from ..ast.expressions import (
     PsAnd,
     PsOr,
     PsNot,
+    PsMemAcc
 )
 
 from ..constants import PsConstant
@@ -275,16 +276,36 @@ class FreezeExpressions:
         return PsSymbolExpr(symb)
 
     def map_Tuple(self, expr: sp.Tuple) -> PsArrayInitList:
+        if not expr:
+            raise FreezeError("Cannot translate an empty tuple.")
+
         items = [self.visit_expr(item) for item in expr]
-        return PsArrayInitList(items)
+        
+        if any(isinstance(i, PsArrayInitList) for i in items):
+            #  base case: have nested arrays
+            if not all(isinstance(i, PsArrayInitList) for i in items):
+                raise FreezeError(
+                    f"Cannot translate nested arrays of non-uniform shape: {expr}"
+                )
+            
+            subarrays = cast(list[PsArrayInitList], items)
+            shape_tail = subarrays[0].shape
+            
+            if not all(s.shape == shape_tail for s in subarrays[1:]):
+                raise FreezeError(
+                    f"Cannot translate nested arrays of non-uniform shape: {expr}"
+                )
+            
+            return PsArrayInitList([s.items_grid for s in subarrays])  # type: ignore
+        else:
+            #  base case: no nested arrays
+            return PsArrayInitList(items)
 
     def map_Indexed(self, expr: sp.Indexed) -> PsSubscript:
         assert isinstance(expr.base, sp.IndexedBase)
         base = self.visit_expr(expr.base.label)
-        subscript = PsSubscript(base, self.visit_expr(expr.indices[0]))
-        for idx in expr.indices[1:]:
-            subscript = PsSubscript(subscript, self.visit_expr(idx))
-        return subscript
+        indices = [self.visit_expr(i) for i in expr.indices]
+        return PsSubscript(base, indices)
 
     def map_Access(self, access: Field.Access):
         field = access.field
@@ -429,6 +450,8 @@ class FreezeExpressions:
                 )
             case AddressOf():
                 return PsAddressOf(*args)
+            case mem_acc():
+                return PsMemAcc(*args)
             case _:
                 raise FreezeError(f"Unsupported function: {func}")
 
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index ce2d24f98..819d4a12b 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TypeVar
+from typing import TypeVar, Callable
 
 from .context import KernelCreationContext
 from ...types import (
@@ -35,11 +35,11 @@ from ..ast.expressions import (
     PsCall,
     PsTernary,
     PsCast,
-    PsDeref,
     PsAddressOf,
     PsConstantExpr,
     PsLookup,
     PsSubscript,
+    PsMemAcc,
     PsSymbolExpr,
     PsLiteralExpr,
     PsRel,
@@ -47,6 +47,7 @@ from ..ast.expressions import (
     PsNot,
 )
 from ..functions import PsMathFunction, CFunction
+from ..ast.util import determine_memory_object
 
 __all__ = ["Typifier"]
 
@@ -57,38 +58,40 @@ class TypificationError(Exception):
 
 NodeT = TypeVar("NodeT", bound=PsAstNode)
 
+ResolutionHook = Callable[[PsType], None]
+
 
 class TypeContext:
     """Typing context, with support for type inference and checking.
 
     Instances of this class are used to propagate and check data types across expression subtrees
-    of the AST. Each type context has:
-
-     - A target type `target_type`, which shall be applied to all expressions it covers
-     - A set of restrictions on the target type:
-       - `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides
-       - Additional restrictions may be added in the future.
+    of the AST. Each type context has a target type `target_type`, which shall be applied to all expressions it covers
     """
 
     def __init__(
         self,
         target_type: PsType | None = None,
-        require_nonconst: bool = False,
     ):
-        self._require_nonconst = require_nonconst
         self._deferred_exprs: list[PsExpression] = []
 
-        self._target_type = (
-            self._fix_constness(target_type) if target_type is not None else None
-        )
+        self._target_type = deconstify(target_type) if target_type is not None else None
+
+        self._hooks: list[ResolutionHook] = []
 
     @property
     def target_type(self) -> PsType | None:
         return self._target_type
 
-    @property
-    def require_nonconst(self) -> bool:
-        return self._require_nonconst
+    def add_hook(self, hook: ResolutionHook):
+        """Adds a resolution hook to this context.
+
+        The hook will be called with the context's target type as soon as it becomes known,
+        which might be immediately.
+        """
+        if self._target_type is None:
+            self._hooks.append(hook)
+        else:
+            hook(self._target_type)
 
     def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
         """Applies the given ``dtype`` to this type context, and optionally to the given expression.
@@ -102,7 +105,7 @@ class TypeContext:
         and will be replaced by it.
         """
 
-        dtype = self._fix_constness(dtype)
+        dtype = deconstify(dtype)
 
         if self._target_type is not None and dtype != self._target_type:
             raise TypificationError(
@@ -134,6 +137,12 @@ class TypeContext:
             self._apply_target_type(expr)
 
     def _propagate_target_type(self):
+        assert self._target_type is not None
+
+        for hook in self._hooks:
+            hook(self._target_type)
+        self._hooks = []
+
         for expr in self._deferred_exprs:
             self._apply_target_type(expr)
         self._deferred_exprs = []
@@ -211,30 +220,10 @@ class TypeContext:
 
     def _compatible(self, dtype: PsType):
         """Checks whether the given data type is compatible with the context's target type.
-
-        If the target type is ``const``, they must be equal up to const qualification;
-        if the target type is not ``const``, `dtype` must match it exactly.
+        The two must match except for constness.
         """
         assert self._target_type is not None
-        if self._target_type.const:
-            return constify(dtype) == self._target_type
-        else:
-            return dtype == self._target_type
-
-    def _fix_constness(self, dtype: PsType, expr: PsExpression | None = None):
-        if self._require_nonconst:
-            if dtype.const:
-                if expr is None:
-                    raise TypificationError(
-                        f"Type mismatch: Encountered {dtype} in non-constant context."
-                    )
-                else:
-                    raise TypificationError(
-                        f"Type mismatch at expression {expr}: Encountered {dtype} in non-constant context."
-                    )
-            return dtype
-        else:
-            return constify(dtype)
+        return deconstify(dtype) == self._target_type
 
 
 class Typifier:
@@ -276,11 +265,7 @@ class Typifier:
 
     **Typing of symbol expressions**
 
-    Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but
-    not necessarily their const-qualification.
-    A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type,
-    and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`,
-    but not vice versa.
+    Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types.
     """
 
     def __init__(self, ctx: KernelCreationContext):
@@ -316,7 +301,44 @@ class Typifier:
                 for s in statements:
                     self.visit(s)
 
-            case PsDeclaration(lhs, rhs):
+            case PsDeclaration(lhs, rhs) if isinstance(rhs, PsArrayInitList):
+                #   Special treatment for array declarations
+                assert isinstance(lhs, PsSymbolExpr)
+
+                decl_tc = TypeContext()
+                items_tc = TypeContext()
+
+                if (lhs_type := lhs.symbol.dtype) is not None:
+                    if not isinstance(lhs_type, PsArrayType):
+                        raise TypificationError(
+                            f"Illegal LHS type in array declaration: {lhs_type}"
+                        )
+
+                    if lhs_type.shape != rhs.shape:
+                        raise TypificationError(
+                            f"Incompatible shapes in declaration of array symbol {lhs.symbol}.\n"
+                            f"  Symbol shape: {lhs_type.shape}\n"
+                            f"   Array shape: {rhs.shape}"
+                        )
+
+                    items_tc.apply_dtype(lhs_type.base_type)
+                    decl_tc.apply_dtype(lhs_type, lhs)
+                else:
+                    decl_tc.infer_dtype(lhs)
+
+                for item in rhs.items:
+                    self.visit_expr(item, items_tc)
+
+                if items_tc.target_type is None:
+                    items_tc.apply_dtype(self._ctx.default_dtype)
+
+                if decl_tc.target_type is None:
+                    assert items_tc.target_type is not None
+                    decl_tc.apply_dtype(
+                        PsArrayType(items_tc.target_type, rhs.shape), rhs
+                    )
+
+            case PsDeclaration(lhs, rhs) | PsAssignment(lhs, rhs):
                 #   Only if the LHS is an untyped symbol, infer its type from the RHS
                 infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None
 
@@ -333,27 +355,11 @@ class Typifier:
                 if infer_lhs and tc.target_type is None:
                     #   no type has been inferred -> use the default dtype
                     tc.apply_dtype(self._ctx.default_dtype)
-
-            case PsAssignment(lhs, rhs):
-                infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None
-
-                tc_lhs = TypeContext(require_nonconst=True)
-
-                if infer_lhs:
-                    tc_lhs.infer_dtype(lhs)
-                else:
-                    self.visit_expr(lhs, tc_lhs)
-                    assert tc_lhs.target_type is not None
-
-                tc_rhs = TypeContext(target_type=tc_lhs.target_type)
-                self.visit_expr(rhs, tc_rhs)
-
-                if infer_lhs:
-                    if tc_rhs.target_type is None:
-                        tc_rhs.apply_dtype(self._ctx.default_dtype)
-                    
-                    assert tc_rhs.target_type is not None
-                    tc_lhs.apply_dtype(deconstify(tc_rhs.target_type))
+                elif not isinstance(node, PsDeclaration):
+                    #   check mutability of LHS
+                    _, lhs_const = determine_memory_object(lhs)
+                    if lhs_const:
+                        raise TypificationError(f"Cannot assign to immutable LHS {lhs}")
 
             case PsConditional(cond, branch_true, branch_false):
                 cond_tc = TypeContext(PsBoolType())
@@ -409,49 +415,56 @@ class Typifier:
 
             case PsArrayAccess(bptr, idx):
                 tc.apply_dtype(bptr.array.element_type, expr)
+                self._handle_idx(idx)
+
+            case PsMemAcc(ptr, offset):
+                ptr_tc = TypeContext()
+                self.visit_expr(ptr, ptr_tc)
 
-                index_tc = TypeContext()
-                self.visit_expr(idx, index_tc)
-                if index_tc.target_type is None:
-                    index_tc.apply_dtype(self._ctx.index_dtype, idx)
-                elif not isinstance(index_tc.target_type, PsIntegerType):
+                if not isinstance(ptr_tc.target_type, PsPointerType):
                     raise TypificationError(
-                        f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
+                        f"Type of pointer argument to memory access was not a pointer type: {ptr_tc.target_type}"
                     )
 
-            case PsSubscript(arr, idx):
-                arr_tc = TypeContext()
-                self.visit_expr(arr, arr_tc)
+                tc.apply_dtype(ptr_tc.target_type.base_type, expr)
+                self._handle_idx(offset)
 
-                if not isinstance(arr_tc.target_type, PsDereferencableType):
-                    raise TypificationError(
-                        "Type of subscript base is not subscriptable."
-                    )
+            case PsSubscript(arr, indices):
+                if isinstance(arr, PsArrayInitList):
+                    shape = arr.shape
 
-                tc.apply_dtype(arr_tc.target_type.base_type, expr)
+                    #   extend outer context over the init-list entries
+                    for item in arr.items:
+                        self.visit_expr(item, tc)
 
-                index_tc = TypeContext()
-                self.visit_expr(idx, index_tc)
-                if index_tc.target_type is None:
-                    index_tc.apply_dtype(self._ctx.index_dtype, idx)
-                elif not isinstance(index_tc.target_type, PsIntegerType):
-                    raise TypificationError(
-                        f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}"
-                    )
+                    #   learn the array type from the items
+                    def arr_hook(element_type: PsType):
+                        arr.dtype = PsArrayType(element_type, arr.shape)
 
-            case PsDeref(ptr):
-                ptr_tc = TypeContext()
-                self.visit_expr(ptr, ptr_tc)
+                    tc.add_hook(arr_hook)
+                else:
+                    #   type of array has to be known
+                    arr_tc = TypeContext()
+                    self.visit_expr(arr, arr_tc)
 
-                if not isinstance(ptr_tc.target_type, PsDereferencableType):
+                    if not isinstance(arr_tc.target_type, PsArrayType):
+                        raise TypificationError(
+                            f"Type of array argument to subscript was not an array type: {arr_tc.target_type}"
+                        )
+
+                    tc.apply_dtype(arr_tc.target_type.base_type, expr)
+                    shape = arr_tc.target_type.shape
+
+                if len(indices) != len(shape):
                     raise TypificationError(
-                        "Type of argument to a Deref is not dereferencable"
+                        f"Invalid number of indices to {len(shape)}-dimensional array: {len(indices)}"
                     )
 
-                tc.apply_dtype(ptr_tc.target_type.base_type, expr)
+                for idx in indices:
+                    self._handle_idx(idx)
 
             case PsAddressOf(arg):
-                if not isinstance(arg, (PsSymbolExpr, PsSubscript, PsDeref, PsLookup)):
+                if not isinstance(arg, (PsSymbolExpr, PsSubscript, PsMemAcc, PsLookup)):
                     raise TypificationError(
                         f"Illegal expression below AddressOf operator: {arg}"
                     )
@@ -468,8 +481,8 @@ class Typifier:
                 match arg:
                     case PsSymbolExpr(s):
                         pointed_to_type = s.get_dtype()
-                    case PsSubscript(arr, _) | PsDeref(arr):
-                        arr_type = arr.get_dtype()
+                    case PsSubscript(ptr, _) | PsMemAcc(ptr, _):
+                        arr_type = ptr.get_dtype()
                         assert isinstance(arr_type, PsDereferencableType)
                         pointed_to_type = arr_type.base_type
                     case PsLookup(aggr, member_name):
@@ -491,7 +504,7 @@ class Typifier:
 
             case PsLookup(aggr, member_name):
                 #   Members of a struct type inherit the struct type's `const` qualifier
-                aggr_tc = TypeContext(None, require_nonconst=tc.require_nonconst)
+                aggr_tc = TypeContext()
                 self.visit_expr(aggr, aggr_tc)
                 aggr_type = aggr_tc.target_type
 
@@ -566,32 +579,11 @@ class Typifier:
                             f"Don't know how to typify calls to {function}"
                         )
 
-            case PsArrayInitList(items):
-                items_tc = TypeContext()
-                for item in items:
-                    self.visit_expr(item, items_tc)
-
-                if items_tc.target_type is None:
-                    if tc.target_type is None:
-                        raise TypificationError(f"Unable to infer type of array {expr}")
-                    elif not isinstance(tc.target_type, PsArrayType):
-                        raise TypificationError(
-                            f"Cannot apply type {tc.target_type} to an array initializer."
-                        )
-                    elif (
-                        tc.target_type.length is not None
-                        and tc.target_type.length != len(items)
-                    ):
-                        raise TypificationError(
-                            "Array size mismatch: Cannot typify initializer list with "
-                            f"{len(items)} items as {tc.target_type}"
-                        )
-                    else:
-                        items_tc.apply_dtype(tc.target_type.base_type)
-                        tc.infer_dtype(expr)
-                else:
-                    arr_type = PsArrayType(items_tc.target_type, len(items))
-                    tc.apply_dtype(arr_type, expr)
+            case PsArrayInitList(_):
+                raise TypificationError(
+                    "Unable to typify array initializer in isolation.\n"
+                    f"    Array: {expr}"
+                )
 
             case PsCast(dtype, arg):
                 arg_tc = TypeContext()
@@ -606,3 +598,16 @@ class Typifier:
 
             case _:
                 raise NotImplementedError(f"Can't typify {expr}")
+
+    def _handle_idx(self, idx: PsExpression):
+        index_tc = TypeContext()
+        self.visit_expr(idx, index_tc)
+
+        if index_tc.target_type is None:
+            index_tc.apply_dtype(self._ctx.index_dtype, idx)
+        elif not isinstance(index_tc.target_type, PsIntegerType):
+            raise TypificationError(
+                f"Invalid data type in index expression.\n"
+                f"    Expression: {idx}\n"
+                f"          Type: {index_tc.target_type}"
+            )
diff --git a/src/pystencils/backend/literals.py b/src/pystencils/backend/literals.py
index dc254da0e..976e6b203 100644
--- a/src/pystencils/backend/literals.py
+++ b/src/pystencils/backend/literals.py
@@ -6,7 +6,7 @@ class PsLiteral:
     """Representation of literal code.
 
     Instances of this class represent code literals inside the AST.
-    These literals are not to be confused with C literals; the name `Literal` refers to the fact that
+    These literals are not to be confused with C literals; the name 'Literal' refers to the fact that
     the code generator takes them "literally", printing them as they are.
 
     Each literal has to be annotated with a type, and is considered constant within the scope of a kernel.
diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py
index 39ec09992..fa42ed021 100644
--- a/src/pystencils/backend/platforms/sycl.py
+++ b/src/pystencils/backend/platforms/sycl.py
@@ -121,7 +121,7 @@ class SyclPlatform(GenericGpu):
         for i, dim in enumerate(dimensions):
             #   Slowest to fastest
             coord = PsExpression.make(PsConstant(i, self._ctx.index_dtype))
-            work_item_idx = PsSubscript(id_symbol, coord)
+            work_item_idx = PsSubscript(id_symbol, (coord,))
 
             dim.counter.dtype = constify(dim.counter.get_dtype())
             work_item_idx.dtype = dim.counter.get_dtype()
@@ -151,7 +151,7 @@ class SyclPlatform(GenericGpu):
         id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type))
 
         zero = PsExpression.make(PsConstant(0, self._ctx.index_dtype))
-        subscript = PsSubscript(id_symbol, zero)
+        subscript = PsSubscript(id_symbol, (zero,))
 
         ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
         subscript.dtype = ispace.sparse_counter.get_dtype()
diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py
index ccaf9fbe9..0c6f6883d 100644
--- a/src/pystencils/backend/platforms/x86.py
+++ b/src/pystencils/backend/platforms/x86.py
@@ -7,7 +7,7 @@ from ..ast.expressions import (
     PsExpression,
     PsVectorArrayAccess,
     PsAddressOf,
-    PsSubscript,
+    PsMemAcc,
 )
 from ..transformations.select_intrinsics import IntrinsicOps
 from ...types import PsCustomType, PsVectorType, PsPointerType
@@ -145,7 +145,7 @@ class X86VectorCpu(GenericVectorCpu):
         if acc.stride == 1:
             load_func = _x86_packed_load(self._vector_arch, acc.dtype, False)
             return load_func(
-                PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index))
+                PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index))
             )
         else:
             raise NotImplementedError("Gather loads not implemented yet.")
@@ -154,7 +154,7 @@ class X86VectorCpu(GenericVectorCpu):
         if acc.stride == 1:
             store_func = _x86_packed_store(self._vector_arch, acc.dtype, False)
             return store_func(
-                PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)),
+                PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index)),
                 arg,
             )
         else:
diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py
index 03d79a689..7404abd94 100644
--- a/src/pystencils/backend/transformations/erase_anonymous_structs.py
+++ b/src/pystencils/backend/transformations/erase_anonymous_structs.py
@@ -8,7 +8,7 @@ from ..ast.expressions import (
     PsArrayAccess,
     PsLookup,
     PsExpression,
-    PsDeref,
+    PsMemAcc,
     PsAddressOf,
     PsCast,
 )
@@ -99,8 +99,9 @@ class EraseAnonymousStructTypes:
         )
         type_erased_access = PsArrayAccess(type_erased_bp, byte_index)
 
-        deref = PsDeref(
-            PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access))
+        deref = PsMemAcc(
+            PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)),
+            PsExpression.make(PsConstant(0))
         )
 
         typify = Typifier(self._ctx)
diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
index 2368868a9..d4dfd3d04 100644
--- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
+++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
@@ -9,12 +9,14 @@ from ..ast.expressions import (
     PsConstantExpr,
     PsLiteralExpr,
     PsCall,
-    PsDeref,
+    PsArrayAccess,
     PsSubscript,
+    PsLookup,
     PsUnOp,
     PsBinOp,
     PsArrayInitList,
 )
+from ..ast.util import determine_memory_object
 
 from ...types import PsDereferencableType
 from ..symbols import PsSymbol
@@ -48,7 +50,11 @@ class HoistContext:
             case PsCall(func):
                 return isinstance(func, PsMathFunction) and args_invariant(expr)
 
-            case PsSubscript(ptr, _) | PsDeref(ptr):
+            case PsSubscript() | PsLookup():
+                return determine_memory_object(expr)[1] and args_invariant(expr)
+
+            case PsArrayAccess(ptr, _):
+                #   Regular pointer derefs are never invariant, since we cannot reason about aliasing
                 ptr_type = cast(PsDereferencableType, ptr.get_dtype())
                 return ptr_type.base_type.const and args_invariant(expr)
 
diff --git a/src/pystencils/sympyextensions/pointers.py b/src/pystencils/sympyextensions/pointers.py
index c69f9376d..2ebeba7c9 100644
--- a/src/pystencils/sympyextensions/pointers.py
+++ b/src/pystencils/sympyextensions/pointers.py
@@ -31,3 +31,14 @@ class AddressOf(sp.Function):
             return PsPointerType(arg_type, restrict=True, const=True)
         else:
             raise ValueError(f'pystencils supports only non void pointers. Current address_of type: {self.args[0]}')
+
+
+class mem_acc(sp.Function):
+    """Memory access through a raw pointer with an offset.
+    
+    This function should be used to model offset memory accesses through raw pointers.
+    """
+    
+    @classmethod
+    def eval(cls, ptr, offset):
+        return None
diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py
index 61e3d73fd..e6fc4bb78 100644
--- a/src/pystencils/types/types.py
+++ b/src/pystencils/types/types.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import final, Any, Sequence
+from typing import final, Any, Sequence, SupportsIndex
 from dataclasses import dataclass
 
 import numpy as np
@@ -98,31 +98,57 @@ class PsPointerType(PsDereferencableType):
 
 
 class PsArrayType(PsDereferencableType):
-    """C array type of known or unknown size."""
+    """Multidimensional array of fixed shape.
+    
+    The element type of an array is never const; only the array itself can be.
+    If ``element_type`` is const, its constness will be removed.
+    """
 
     def __init__(
-        self, base_type: PsType, length: int | None = None, const: bool = False
+        self, element_type: PsType, shape: SupportsIndex | Sequence[SupportsIndex], const: bool = False
     ):
-        self._length = length
-        super().__init__(base_type, const)
+        from operator import index
+        if isinstance(shape, SupportsIndex):
+            shape = (index(shape),)
+        else:
+            shape = tuple(index(s) for s in shape)
+
+        if not shape or any(s <= 0 for s in shape):
+            raise ValueError(f"Invalid array shape: {shape}")
+        
+        if isinstance(element_type, PsArrayType):
+            raise ValueError("Element type of array cannot be another array.")
+        
+        element_type = deconstify(element_type)
+
+        self._shape = shape
+        super().__init__(element_type, const)
 
     def __args__(self) -> tuple[Any, ...]:
         """
-        >>> t = PsArrayType(PsBoolType(), 13)
+        >>> t = PsArrayType(PsBoolType(), (13, 42))
         >>> t == PsArrayType(*t.__args__())
         True
         """
-        return (self._base_type, self._length)
+        return (self._base_type, self._shape)
 
     @property
-    def length(self) -> int | None:
-        return self._length
+    def shape(self) -> tuple[int, ...]:
+        """Shape of this array"""
+        return self._shape
+    
+    @property
+    def dim(self) -> int:
+        """Dimensionality of this array"""
+        return len(self._shape)
 
     def c_string(self) -> str:
-        return f"{self._base_type.c_string()} [{str(self._length) if self._length is not None else ''}]"
+        arr_brackets = "".join(f"[{s}]" for s in self._shape)
+        const = self._const_string()
+        return const + self._base_type.c_string() + arr_brackets
 
     def __repr__(self) -> str:
-        return f"PsArrayType(element_type={repr(self._base_type)}, size={self._length}, const={self._const})"
+        return f"PsArrayType(element_type={repr(self._base_type)}, shape={self._shape}, const={self._const})"
 
 
 class PsStructType(PsType):
@@ -131,6 +157,8 @@ class PsStructType(PsType):
     A struct type is defined by its sequence of members.
     The struct may optionally have a name, although the code generator currently does not support named structs
     and treats them the same way as anonymous structs.
+
+    Struct member types cannot be ``const``; if a ``const`` member type is passed, its constness will be removed.
     """
 
     @dataclass(frozen=True)
@@ -138,6 +166,10 @@ class PsStructType(PsType):
         name: str
         dtype: PsType
 
+        def __post_init__(self):
+            #   Need to use object.__setattr__ because instances are frozen
+            object.__setattr__(self, "dtype", deconstify(self.dtype))
+
     @staticmethod
     def _canonical_members(members: Sequence[PsStructType.Member | tuple[str, PsType]]):
         return tuple(
diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py
index 072882a7b..341b75601 100644
--- a/tests/nbackend/kernelcreation/test_freeze.py
+++ b/tests/nbackend/kernelcreation/test_freeze.py
@@ -10,6 +10,7 @@ from pystencils import (
     DynamicType,
 )
 from pystencils.sympyextensions import CastFunc
+from pystencils.sympyextensions.pointers import mem_acc
 
 from pystencils.backend.ast.structural import (
     PsAssignment,
@@ -40,6 +41,9 @@ from pystencils.backend.ast.expressions import (
     PsAdd,
     PsMul,
     PsSub,
+    PsArrayInitList,
+    PsSubscript,
+    PsMemAcc,
 )
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.functions import PsMathFunction, MathFunctions
@@ -353,3 +357,100 @@ def test_add_sub():
 
     expr = freeze(x - 2 * y)
     assert expr.structurally_equal(PsAdd(x2, PsMul(minus_two, y2)))
+
+
+def test_tuple_array_literals():
+    ctx = KernelCreationContext()
+    freeze = FreezeExpressions(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+
+    x2 = PsExpression.make(ctx.get_symbol("x"))
+    y2 = PsExpression.make(ctx.get_symbol("y"))
+    z2 = PsExpression.make(ctx.get_symbol("z"))
+
+    one = PsExpression.make(PsConstant(1))
+    three = PsExpression.make(PsConstant(3))
+    four = PsExpression.make(PsConstant(4))
+
+    arr_literal = freeze(sp.Tuple(3 + y, z, z / 4))
+    assert arr_literal.structurally_equal(
+        PsArrayInitList([three + y2, z2, one / four * z2])
+    )
+
+
+def test_nested_tuples():
+    ctx = KernelCreationContext()
+    freeze = FreezeExpressions(ctx)
+
+    def f(n):
+        return freeze(sp.sympify(n))
+
+    shape = (2, 3, 2)
+    symb_arr = sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10)))
+    arr_literal = freeze(symb_arr)
+
+    assert isinstance(arr_literal, PsArrayInitList)
+    assert arr_literal.shape == shape
+
+    assert arr_literal.structurally_equal(
+        PsArrayInitList(
+            [
+                ((f(1), f(2)), (f(3), f(4)), (f(5), f(6))),
+                ((f(5), f(6)), (f(7), f(8)), (f(9), f(10))),
+            ]
+        )
+    )
+
+
+def test_invalid_arrays():
+    ctx = KernelCreationContext()
+
+    freeze = FreezeExpressions(ctx)
+    #   invalid: nonuniform nesting depth
+    symb_arr = sp.Tuple((3, 32), 14)
+    with pytest.raises(FreezeError):
+        _ = freeze(symb_arr)
+
+    #   invalid: nonuniform sub-array length
+    symb_arr = sp.Tuple((3, 32), (14, -7, 3))
+    with pytest.raises(FreezeError):
+        _ = freeze(symb_arr)
+
+    #   invalid: empty subarray
+    symb_arr = sp.Tuple((), (0, -9))
+    with pytest.raises(FreezeError):
+        _ = freeze(symb_arr)
+
+    #   invalid: all subarrays empty
+    symb_arr = sp.Tuple((), ())
+    with pytest.raises(FreezeError):
+        _ = freeze(symb_arr)
+
+
+def test_memory_access():
+    ctx = KernelCreationContext()
+    freeze = FreezeExpressions(ctx)
+
+    ptr = sp.Symbol("ptr")
+    expr = freeze(mem_acc(ptr, 31))
+
+    assert isinstance(expr, PsMemAcc)
+    assert expr.pointer.structurally_equal(PsExpression.make(ctx.get_symbol("ptr")))
+    assert expr.offset.structurally_equal(PsExpression.make(PsConstant(31)))
+
+
+def test_indexed():
+    ctx = KernelCreationContext()
+    freeze = FreezeExpressions(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+    a = sp.IndexedBase("a")
+
+    x2 = PsExpression.make(ctx.get_symbol("x"))
+    y2 = PsExpression.make(ctx.get_symbol("y"))
+    z2 = PsExpression.make(ctx.get_symbol("z"))
+    a2 = PsExpression.make(ctx.get_symbol("a"))
+
+    expr = freeze(a[x, y, z])
+    assert expr.structurally_equal(PsSubscript(a2, (x2, y2, z2)))
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index 4c6a4d602..5ea2aa15e 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -5,6 +5,7 @@ import numpy as np
 from typing import cast
 
 from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment
+from pystencils.sympyextensions.pointers import mem_acc
 
 from pystencils.backend.ast.structural import (
     PsDeclaration,
@@ -32,11 +33,12 @@ from pystencils.backend.ast.expressions import (
     PsLt,
     PsCall,
     PsTernary,
+    PsMemAcc
 )
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.functions import CFunction
 from pystencils.types import constify, create_type, create_numeric_type
-from pystencils.types.quick import Fp, Int, Bool, Arr
+from pystencils.types.quick import Fp, Int, Bool, Arr, Ptr
 from pystencils.backend.kernelcreation.context import KernelCreationContext
 from pystencils.backend.kernelcreation.freeze import FreezeExpressions
 from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
@@ -64,7 +66,7 @@ def test_typify_simple():
     assert isinstance(fasm, PsDeclaration)
 
     def check(expr):
-        assert expr.dtype == constify(ctx.default_dtype)
+        assert expr.dtype == ctx.default_dtype
         match expr:
             case PsConstantExpr(cs):
                 assert cs.value == 2
@@ -82,41 +84,6 @@ def test_typify_simple():
     check(fasm.rhs)
 
 
-def test_rhs_constness():
-    default_type = Fp(32)
-    ctx = KernelCreationContext(default_dtype=default_type)
-
-    freeze = FreezeExpressions(ctx)
-    typify = Typifier(ctx)
-
-    f = Field.create_generic(
-        "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
-    )
-    f_const = Field.create_generic(
-        "f_const",
-        1,
-        index_shape=(1,),
-        dtype=constify(default_type),
-        field_type=FieldType.CUSTOM,
-    )
-
-    x, y, z = sp.symbols("x, y, z")
-
-    #   Right-hand sides should always get const types
-    asm = typify(freeze(Assignment(x, f.absolute_access([0], [0]))))
-    assert asm.rhs.get_dtype().const
-
-    asm = typify(
-        freeze(
-            Assignment(
-                f.absolute_access([0], [0]),
-                f.absolute_access([0], [0]) * f_const.absolute_access([0], [0]) * x + y,
-            )
-        )
-    )
-    assert asm.rhs.get_dtype().const
-
-
 def test_lhs_constness():
     default_type = Fp(32)
     ctx = KernelCreationContext(default_dtype=default_type)
@@ -136,7 +103,7 @@ def test_lhs_constness():
 
     x, y, z = sp.symbols("x, y, z")
 
-    #   Assignment RHS may not be const
+    #   Can assign to non-const LHS
     asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y)))
     assert not asm.lhs.get_dtype().const
 
@@ -158,7 +125,7 @@ def test_lhs_constness():
     q = ctx.get_symbol("q", Fp(32, const=True))
     ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q))
     ast = typify(ast)
-    assert ast.lhs.dtype == Fp(32, const=True)
+    assert ast.lhs.dtype == Fp(32)
 
     ast = PsAssignment(PsExpression.make(q), PsExpression.make(q))
     with pytest.raises(TypificationError):
@@ -200,7 +167,7 @@ def test_default_typing():
     expr = typify(expr)
 
     def check(expr):
-        assert expr.dtype == constify(ctx.default_dtype)
+        assert expr.dtype == ctx.default_dtype
         match expr:
             case PsConstantExpr(cs):
                 assert cs.value in (2, 3, -4)
@@ -217,6 +184,141 @@ def test_default_typing():
     check(expr)
 
 
+def test_inline_arrays_1d():
+    ctx = KernelCreationContext(default_dtype=Fp(32))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    x = sp.Symbol("x")
+    y = TypedSymbol("y", Fp(16))
+    idx = TypedSymbol("idx", Int(32))
+
+    arr: PsArrayInitList = cast(PsArrayInitList, freeze(sp.Tuple(1, 2, 3, 4)))
+    decl = PsDeclaration(freeze(x), freeze(y) + PsSubscript(arr, (freeze(idx),)))
+    #   The array elements should learn their type from the context, which gets it from `y`
+
+    decl = typify(decl)
+    assert decl.lhs.dtype == Fp(16)
+    assert decl.rhs.dtype == Fp(16)
+
+    assert arr.dtype == Arr(Fp(16), (4,))
+    for item in arr.items:
+        assert item.dtype == Fp(16)
+
+
+def test_inline_arrays_3d():
+    ctx = KernelCreationContext(default_dtype=Fp(32))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    x = sp.Symbol("x")
+    y = TypedSymbol("y", Fp(16))
+    idx = [TypedSymbol(f"idx_{i}", Int(32)) for i in range(3)]
+
+    arr: PsArrayInitList = freeze(
+        sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10)))
+    )
+    decl = PsDeclaration(
+        freeze(x),
+        freeze(y) + PsSubscript(arr, (freeze(idx[0]), freeze(idx[1]), freeze(idx[2]))),
+    )
+    #   The array elements should learn their type from the context, which gets it from `y`
+
+    decl = typify(decl)
+    assert decl.lhs.dtype == Fp(16)
+    assert decl.rhs.dtype == Fp(16)
+
+    assert arr.dtype == Arr(Fp(16), (2, 3, 2))
+    assert arr.shape == (2, 3, 2)
+    for item in arr.items:
+        assert item.dtype == Fp(16)
+
+
+def test_array_subscript():
+    ctx = KernelCreationContext(default_dtype=Fp(16))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    arr = sp.IndexedBase(TypedSymbol("arr", Arr(Fp(32), (16,))))
+    expr = freeze(arr[3])
+    expr = typify(expr)
+
+    assert expr.dtype == Fp(32)
+
+    arr = sp.IndexedBase(TypedSymbol("arr2", Arr(Fp(32), (7, 31))))
+    expr = freeze(arr[3, 5])
+    expr = typify(expr)
+
+    assert expr.dtype == Fp(32)
+
+
+def test_invalid_subscript():
+    ctx = KernelCreationContext(default_dtype=Fp(16))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    non_arr = sp.IndexedBase(TypedSymbol("non_arr", Int(64)))
+    expr = freeze(non_arr[3])
+
+    with pytest.raises(TypificationError):
+        expr = typify(expr)
+
+    wrong_shape_arr = sp.IndexedBase(
+        TypedSymbol("wrong_shape_arr", Arr(Fp(32), (7, 31, 5)))
+    )
+    expr = freeze(wrong_shape_arr[3, 5])
+
+    with pytest.raises(TypificationError):
+        expr = typify(expr)
+
+    #   raw pointers are not arrays, cannot enter subscript
+    ptr = sp.IndexedBase(
+        TypedSymbol("ptr", Ptr(Int(16)))
+    )
+    expr = freeze(ptr[37])
+    
+    with pytest.raises(TypificationError):
+        expr = typify(expr)
+
+    
+def test_mem_acc():
+    ctx = KernelCreationContext(default_dtype=Fp(16))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    ptr = TypedSymbol("ptr", Ptr(Int(64)))
+    idx = TypedSymbol("idx", Int(32))
+    
+    expr = freeze(mem_acc(ptr, idx))
+    expr = typify(expr)
+
+    assert isinstance(expr, PsMemAcc)
+    assert expr.dtype == Int(64)
+    assert expr.offset.dtype == Int(32)
+
+
+def test_invalid_mem_acc():
+    ctx = KernelCreationContext(default_dtype=Fp(16))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    non_ptr = TypedSymbol("non_ptr", Int(64))
+    idx = TypedSymbol("idx", Int(32))
+    
+    expr = freeze(mem_acc(non_ptr, idx))
+    
+    with pytest.raises(TypificationError):
+        _ = typify(expr)
+    
+    arr = TypedSymbol("arr", Arr(Int(64), (31,)))
+    idx = TypedSymbol("idx", Int(32))
+    
+    expr = freeze(mem_acc(arr, idx))
+    
+    with pytest.raises(TypificationError):
+        _ = typify(expr)
+
+
 def test_lhs_inference():
     ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
     freeze = FreezeExpressions(ctx)
@@ -232,13 +334,13 @@ def test_lhs_inference():
     fasm = typify(freeze(asm))
 
     assert ctx.get_symbol("x").dtype == Fp(32)
-    assert fasm.lhs.dtype == constify(Fp(32))
+    assert fasm.lhs.dtype == Fp(32)
 
     asm = Assignment(y, 3 - w)
     fasm = typify(freeze(asm))
 
     assert ctx.get_symbol("y").dtype == Fp(16)
-    assert fasm.lhs.dtype == constify(Fp(16))
+    assert fasm.lhs.dtype == Fp(16)
 
     fasm = PsAssignment(PsExpression.make(ctx.get_symbol("z")), freeze(3 - w))
     fasm = typify(fasm)
@@ -252,8 +354,38 @@ def test_lhs_inference():
     fasm = typify(fasm)
 
     assert ctx.get_symbol("r").dtype == Bool()
-    assert fasm.lhs.dtype == constify(Bool())
-    assert fasm.rhs.dtype == constify(Bool())
+    assert fasm.lhs.dtype == Bool()
+    assert fasm.rhs.dtype == Bool()
+
+
+def test_array_declarations():
+    ctx = KernelCreationContext(default_dtype=Fp(32))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+    
+    #   Array type fallback to default
+    arr1 = sp.Symbol("arr1")
+    decl = freeze(Assignment(arr1, sp.Tuple(1, 2, 3, 4)))
+    decl = typify(decl)
+
+    assert ctx.get_symbol("arr1").dtype == Arr(Fp(32), (4,))
+
+    #   Array type determined by default-typed symbol
+    arr2 = sp.Symbol("arr2")
+    decl = freeze(Assignment(arr2, sp.Tuple((x, y, -7), (3, -2, 51))))
+    decl = typify(decl)
+
+    assert ctx.get_symbol("arr2").dtype == Arr(Fp(32), (2, 3))
+
+    #   Array type determined by pre-typed symbol
+    q = TypedSymbol("q", Fp(16))
+    arr3 = sp.Symbol("arr3")
+    decl = freeze(Assignment(arr3, sp.Tuple((q, 2), (-q, 0.123))))
+    decl = typify(decl)
+
+    assert ctx.get_symbol("arr3").dtype == Arr(Fp(16), (2, 2))
 
 
 def test_erronous_typing():
@@ -292,17 +424,17 @@ def test_invalid_indices():
     ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
     typify = Typifier(ctx)
 
-    arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64))))
+    arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64), (61,))))
     x, y, z = [PsExpression.make(ctx.get_symbol(x)) for x in "xyz"]
 
     #   Using default-typed symbols as array indices is illegal when the default type is a float
 
-    fasm = PsAssignment(PsSubscript(arr, x + y), z)
+    fasm = PsAssignment(PsSubscript(arr, (x + y,)), z)
 
     with pytest.raises(TypificationError):
         typify(fasm)
 
-    fasm = PsAssignment(z, PsSubscript(arr, x + y))
+    fasm = PsAssignment(z, PsSubscript(arr, (x + y,)))
 
     with pytest.raises(TypificationError):
         typify(fasm)
@@ -394,7 +526,7 @@ def test_typify_bools_and_relations():
     expr = PsAnd(PsEq(x, y), PsAnd(true, PsNot(PsOr(p, q))))
     expr = typify(expr)
 
-    assert expr.dtype == Bool(const=True)
+    assert expr.dtype == Bool() 
 
 
 def test_bool_in_numerical_context():
@@ -418,7 +550,7 @@ def test_typify_conditionals(rel):
 
     cond = PsConditional(rel(x, y), PsBlock([]))
     cond = typify(cond)
-    assert cond.condition.dtype == Bool(const=True)
+    assert cond.condition.dtype == Bool()
 
 
 def test_invalid_conditions():
@@ -447,11 +579,11 @@ def test_typify_ternary():
 
     expr = PsTernary(p, x, y)
     expr = typify(expr)
-    assert expr.dtype == Fp(32, const=True)
+    assert expr.dtype == Fp(32)
 
     expr = PsTernary(PsAnd(p, q), a, b + a)
     expr = typify(expr)
-    assert expr.dtype == Int(32, const=True)
+    assert expr.dtype == Int(32)
 
     expr = PsTernary(PsAnd(p, q), a, x)
     with pytest.raises(TypificationError):
@@ -475,9 +607,9 @@ def test_cfunction():
 
     result = typify(PsCall(threeway, [x, y]))
 
-    assert result.get_dtype() == Int(32, const=True)
-    assert result.args[0].get_dtype() == Fp(32, const=True)
-    assert result.args[1].get_dtype() == Fp(32, const=True)
+    assert result.get_dtype() == Int(32)
+    assert result.args[0].get_dtype() == Fp(32)
+    assert result.args[1].get_dtype() == Fp(32)
 
     with pytest.raises(TypificationError):
         _ = typify(PsCall(threeway, (x, p)))
diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py
index 09c63a557..02f03bfa9 100644
--- a/tests/nbackend/test_ast.py
+++ b/tests/nbackend/test_ast.py
@@ -3,7 +3,8 @@ from pystencils.backend.constants import PsConstant
 from pystencils.backend.ast.expressions import (
     PsExpression,
     PsCast,
-    PsDeref,
+    PsMemAcc,
+    PsArrayInitList,
     PsSubscript,
 )
 from pystencils.backend.ast.structural import (
@@ -45,6 +46,10 @@ def test_cloning():
         PsConditional(
             y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")])
         ),
+        PsArrayInitList([
+            [x, y, one + x],
+            [one, c2, z]
+        ]),
         PsPragma("omp parallel for"),
         PsLoop(
             x,
@@ -58,8 +63,8 @@ def test_cloning():
                     PsAssignment(x, y),
                     PsPragma("#pragma clang loop vectorize(enable)"),
                     PsStatement(
-                        PsDeref(PsCast(Ptr(Fp(32)), z))
-                        + PsSubscript(z, one + one + one)
+                        PsMemAcc(PsCast(Ptr(Fp(32)), z), one)
+                        + PsSubscript(z, (one + one + one, y + one))
                     ),
                 ]
             ),
diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py
index 8fb44e748..dc7a86b0b 100644
--- a/tests/nbackend/test_code_printing.py
+++ b/tests/nbackend/test_code_printing.py
@@ -55,14 +55,15 @@ def test_printing_integer_functions():
         PsBitwiseOr,
         PsBitwiseXor,
         PsIntDiv,
-        PsRem
+        PsRem,
     )
 
     expr = PsBitwiseAnd(
         PsBitwiseXor(
             PsBitwiseXor(j, k),
             PsBitwiseOr(PsLeftShift(i, PsRightShift(j, k)), PsIntDiv(i, k)),
-        ) + PsRem(i, k),
+        )
+        + PsRem(i, k),
         i,
     )
     code = cprint(expr)
@@ -154,3 +155,32 @@ def test_ternary():
     expr = PsTernary(PsTernary(p, q, PsOr(p, q)), x, y)
     code = cprint(expr)
     assert code == "(p ? q : p || q) ? x : y"
+
+
+def test_arrays():
+    import sympy as sp
+    from pystencils import Assignment
+    from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory
+
+    ctx = KernelCreationContext(default_dtype=SInt(32))
+    factory = AstFactory(ctx)
+    cprint = CAstPrinter()
+
+    arr_1d = factory.parse_sympy(Assignment(sp.Symbol("a1d"), sp.Tuple(1, 2, 3, 4, 5)))
+    code = cprint(arr_1d)
+    assert code == "int32_t a1d[5] = { 1, 2, 3, 4, 5 };"
+
+    arr_2d = factory.parse_sympy(
+        Assignment(sp.Symbol("a2d"), sp.Tuple((1, -1), (2, -2)))
+    )
+    code = cprint(arr_2d)
+    assert code == "int32_t a2d[2][2] = { { 1, -1 }, { 2, -2 } };"
+
+    arr_3d = factory.parse_sympy(
+        Assignment(sp.Symbol("a3d"), sp.Tuple(((1, -1), (2, -2)), ((3, -3), (4, -4))))
+    )
+    code = cprint(arr_3d)
+    assert (
+        code
+        == "int32_t a3d[2][2][2] = { { { 1, -1 }, { 2, -2 } }, { { 3, -3 }, { 4, -4 } } };"
+    )
diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py
index 16e610a55..914d05594 100644
--- a/tests/nbackend/test_extensions.py
+++ b/tests/nbackend/test_extensions.py
@@ -18,13 +18,13 @@ def test_literals():
     f = Field.create_generic("f", 3)
     x = sp.Symbol("x")
     
-    cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64, const=True), 3)))
+    cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64), (3,), const=True)))
     global_constant = PsExpression.make(PsLiteral("C", ctx.default_dtype))
 
     loop_slice = make_slice[
-        0:PsSubscript(cells, factory.parse_index(0)),
-        0:PsSubscript(cells, factory.parse_index(1)),
-        0:PsSubscript(cells, factory.parse_index(2)),
+        0:PsSubscript(cells, (factory.parse_index(0),)),
+        0:PsSubscript(cells, (factory.parse_index(1),)),
+        0:PsSubscript(cells, (factory.parse_index(2),)),
     ]
 
     ispace = FullIterationSpace.create_from_slice(ctx, loop_slice)
diff --git a/tests/nbackend/transformations/test_canonical_clone.py b/tests/nbackend/transformations/test_canonical_clone.py
index b158b9178..b5e100ea5 100644
--- a/tests/nbackend/transformations/test_canonical_clone.py
+++ b/tests/nbackend/transformations/test_canonical_clone.py
@@ -22,8 +22,8 @@ def test_clone_entire_ast():
     rho = sp.Symbol("rho")
     u = sp.symbols("u_:2")
 
-    cx = TypedSymbol("cx", Arr(ctx.default_dtype))
-    cy = TypedSymbol("cy", Arr(ctx.default_dtype))
+    cx = TypedSymbol("cx", Arr(ctx.default_dtype, (5,)))
+    cy = TypedSymbol("cy", Arr(ctx.default_dtype, (5,)))
     cxs = sp.IndexedBase(cx, shape=(5,))
     cys = sp.IndexedBase(cy, shape=(5,))
 
diff --git a/tests/nbackend/transformations/test_hoist_invariants.py b/tests/nbackend/transformations/test_hoist_invariants.py
index 15514f1da..daa2760c0 100644
--- a/tests/nbackend/transformations/test_hoist_invariants.py
+++ b/tests/nbackend/transformations/test_hoist_invariants.py
@@ -144,14 +144,14 @@ def test_hoist_arrays():
 
     const_arr_symb = TypedSymbol(
         "const_arr",
-        Arr(Fp(64, const=True), 10),
+        Arr(Fp(64), (10,), const=True),
     )
     const_array_decl = factory.parse_sympy(Assignment(const_arr_symb, tuple(range(10))))
     const_arr = sp.IndexedBase(const_arr_symb, shape=(10,))
 
     arr_symb = TypedSymbol(
         "arr",
-        Arr(Fp(64, const=False), 10),
+        Arr(Fp(64), (10,), const=False),
     )
     array_decl = factory.parse_sympy(Assignment(arr_symb, tuple(range(10))))
     arr = sp.IndexedBase(arr_symb, shape=(10,))
diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py
index 1cc2ae0e4..165d572de 100644
--- a/tests/nbackend/types/test_types.py
+++ b/tests/nbackend/types/test_types.py
@@ -151,6 +151,48 @@ def test_struct_types():
     assert t.itemsize == numpy_type.itemsize == 16
 
 
+def test_array_types():
+    t = PsArrayType(UInt(64), 42)
+    assert t.dim == 1
+    assert t.shape == (42,)
+    assert not t.const
+    assert t.c_string() == "uint64_t[42]"
+
+    assert t == PsArrayType(UInt(64), (42,))
+
+    t = PsArrayType(UInt(64), [3, 4, 5])
+    assert t.dim == 3
+    assert t.shape == (3, 4, 5)
+    assert not t.const
+    assert t.c_string() == "uint64_t[3][4][5]"
+
+    t = PsArrayType(UInt(64, const=True), [3, 4, 5])
+    assert t.dim == 3
+    assert t.shape == (3, 4, 5)
+    assert not t.const
+
+    t = PsArrayType(UInt(64), [3, 4, 5], const=True)
+    assert t.dim == 3
+    assert t.shape == (3, 4, 5)
+    assert t.const
+    assert t.c_string() == "const uint64_t[3][4][5]"
+
+    t = PsArrayType(UInt(64, const=True), [3, 4, 5], const=True)
+    assert t.dim == 3
+    assert t.shape == (3, 4, 5)
+    assert t.const
+
+    with pytest.raises(ValueError):
+        _ = PsArrayType(UInt(64), (3, 0, 1))
+
+    with pytest.raises(ValueError):
+        _ = PsArrayType(UInt(64), (3, 9, -1, 2))
+
+    #   Nested arrays are disallowed
+    with pytest.raises(ValueError):
+        _ = PsArrayType(PsArrayType(Bool(), (2,)), (3, 1))
+
+
 def test_pickle():
     types = [
         Bool(const=True),
@@ -165,8 +207,8 @@ def test_pickle():
         Fp(width=16, const=True),
         PsStructType([("x", UInt(32)), ("y", UInt(32)), ("val", Fp(64))], "myStruct"),
         PsStructType([("data", Fp(32))], "None"),
-        PsArrayType(Fp(16, const=True), 42),
-        PsArrayType(PsVectorType(Fp(32), 8, const=False), 42)
+        PsArrayType(Fp(16), (42,), const=True),
+        PsArrayType(PsVectorType(Fp(32), 8), (42,))
     ]
 
     dumped = pickle.dumps(types)
-- 
GitLab