From 4bbfb3b65bb91ca9e697ec1abfde84ffa70b446e Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 21 Oct 2024 08:48:31 +0200 Subject: [PATCH] Revised Array Modelling & Memory Model --- .../source/api/symbolic_language/astnodes.rst | 2 +- docs/source/backend/objects.rst | 41 ++- src/pystencils/backend/ast/expressions.py | 161 ++++++++--- src/pystencils/backend/ast/util.py | 49 +++- src/pystencils/backend/emission.py | 51 ++-- .../backend/kernelcreation/ast_factory.py | 6 +- .../backend/kernelcreation/freeze.py | 35 ++- .../backend/kernelcreation/typification.py | 253 +++++++++--------- src/pystencils/backend/literals.py | 2 +- src/pystencils/backend/platforms/sycl.py | 4 +- src/pystencils/backend/platforms/x86.py | 6 +- .../erase_anonymous_structs.py | 7 +- .../hoist_loop_invariant_decls.py | 10 +- src/pystencils/sympyextensions/pointers.py | 11 + src/pystencils/types/types.py | 54 +++- tests/nbackend/kernelcreation/test_freeze.py | 101 +++++++ .../kernelcreation/test_typification.py | 240 +++++++++++++---- tests/nbackend/test_ast.py | 11 +- tests/nbackend/test_code_printing.py | 34 ++- tests/nbackend/test_extensions.py | 8 +- .../transformations/test_canonical_clone.py | 4 +- .../transformations/test_hoist_invariants.py | 4 +- tests/nbackend/types/test_types.py | 46 +++- 23 files changed, 853 insertions(+), 287 deletions(-) diff --git a/docs/source/api/symbolic_language/astnodes.rst b/docs/source/api/symbolic_language/astnodes.rst index 4d5c4b89f..ff31c98ec 100644 --- a/docs/source/api/symbolic_language/astnodes.rst +++ b/docs/source/api/symbolic_language/astnodes.rst @@ -4,6 +4,6 @@ Kernel Structure .. automodule:: pystencils.sympyextensions.astnodes -.. autoclass:: pystencils.sympyextensions.AssignmentCollection +.. autoclass:: pystencils.AssignmentCollection :members: diff --git a/docs/source/backend/objects.rst b/docs/source/backend/objects.rst index b0c3af6db..1b36842b8 100644 --- a/docs/source/backend/objects.rst +++ b/docs/source/backend/objects.rst @@ -1,15 +1,44 @@ -***************************** -Symbols, Constants and Arrays -***************************** +**************************** +Constants and Memory Objects +**************************** + +Memory Objects: Symbols and Field Arrays +======================================== + +The Memory Model +---------------- + +In order to reason about memory accesses, mutability, invariance, and aliasing, the *pystencils* backend uses +a very simple memory model. There are three types of memory objects: + +- Symbols (`PsSymbol`), which act as registers for data storage within the scope of a kernel +- Field arrays (`PsLinearizedArray`), which represent a contiguous block of memory the kernel has access to, and +- the *unmanaged heap*, which is a global catch-all memory object which all pointers not belonging to a field + array point into. + +All of these objects are disjoint, and cannot alias each other. +Each symbol exists in isolation, +field arrays do not overlap, +and raw pointers are assumed not to point into memory owned by a symbol or field array. +Instead, all raw pointers point into unmanaged heap memory, and are assumed to *always* alias one another: +Each change brought to unmanaged memory by one raw pointer is assumed to affect the memory pointed to by +another raw pointer. + +Classes +------- .. autoclass:: pystencils.backend.symbols.PsSymbol :members: -.. autoclass:: pystencils.backend.constants.PsConstant +.. automodule:: pystencils.backend.arrays :members: -.. autoclass:: pystencils.backend.literals.PsLiteral + +Constants and Literals +====================== + +.. autoclass:: pystencils.backend.constants.PsConstant :members: -.. automodule:: pystencils.backend.arrays +.. autoclass:: pystencils.backend.literals.PsLiteral :members: diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 5f9c95d5d..151f86c6e 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -1,8 +1,12 @@ from __future__ import annotations + from abc import ABC, abstractmethod from typing import Sequence, overload, Callable, Any, cast import operator +import numpy as np +from numpy.typing import NDArray + from ..symbols import PsSymbol from ..constants import PsConstant from ..literals import PsLiteral @@ -99,7 +103,8 @@ class PsExpression(PsAstNode, ABC): class PsLvalue(ABC): - """Mix-in for all expressions that may occur as an lvalue""" + """Mix-in for all expressions that may occur as an lvalue; + i.e. expressions that represent a memory location.""" class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression): @@ -189,48 +194,99 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): class PsSubscript(PsLvalue, PsExpression): - __match_args__ = ("base", "index") + """N-dimensional subscript into an array.""" + + __match_args__ = ("array", "index") - def __init__(self, base: PsExpression, index: PsExpression): + def __init__(self, arr: PsExpression, index: Sequence[PsExpression]): super().__init__() - self._base = base - self._index = index + self._arr = arr + + if not index: + raise ValueError("Subscript index cannot be empty.") + + self._index = list(index) @property - def base(self) -> PsExpression: - return self._base + def array(self) -> PsExpression: + return self._arr - @base.setter - def base(self, expr: PsExpression): - self._base = expr + @array.setter + def array(self, expr: PsExpression): + self._arr = expr @property - def index(self) -> PsExpression: + def index(self) -> list[PsExpression]: return self._index @index.setter - def index(self, expr: PsExpression): - self._index = expr + def index(self, idx: Sequence[PsExpression]): + self._index = list(idx) def clone(self) -> PsSubscript: - return PsSubscript(self._base.clone(), self._index.clone()) + return PsSubscript(self._arr.clone(), [i.clone() for i in self._index]) def get_children(self) -> tuple[PsAstNode, ...]: - return (self._base, self._index) + return (self._arr,) + tuple(self._index) + + def set_child(self, idx: int, c: PsAstNode): + idx = range(len(self._index) + 1)[idx] + match idx: + case 0: + self.array = failing_cast(PsExpression, c) + case _: + self.index[idx - 1] = failing_cast(PsExpression, c) + + def __repr__(self) -> str: + idx = ", ".join(repr(i) for i in self._index) + return f"PsSubscript({self._arr}, ({idx}))" + + +class PsMemAcc(PsLvalue, PsExpression): + """Pointer-based memory access with type-dependent offset.""" + + __match_args__ = ("pointer", "offset") + + def __init__(self, ptr: PsExpression, offset: PsExpression): + super().__init__() + self._ptr = ptr + self._offset = offset + + @property + def pointer(self) -> PsExpression: + return self._ptr + + @pointer.setter + def pointer(self, expr: PsExpression): + self._ptr = expr + + @property + def offset(self) -> PsExpression: + return self._offset + + @offset.setter + def offset(self, expr: PsExpression): + self._offset = expr + + def clone(self) -> PsMemAcc: + return PsMemAcc(self._ptr.clone(), self._offset.clone()) + + def get_children(self) -> tuple[PsAstNode, ...]: + return (self._ptr, self._offset) def set_child(self, idx: int, c: PsAstNode): idx = [0, 1][idx] match idx: case 0: - self.base = failing_cast(PsExpression, c) + self.pointer = failing_cast(PsExpression, c) case 1: - self.index = failing_cast(PsExpression, c) + self.offset = failing_cast(PsExpression, c) def __repr__(self) -> str: - return f"Subscript({self._base})[{self._index}]" + return f"PsMemAcc({repr(self._ptr)}, {repr(self._offset)})" -class PsArrayAccess(PsSubscript): +class PsArrayAccess(PsMemAcc): __match_args__ = ("base_ptr", "index") def __init__(self, base_ptr: PsArrayBasePointer, index: PsExpression): @@ -243,11 +299,11 @@ class PsArrayAccess(PsSubscript): return self._base_ptr @property - def base(self) -> PsExpression: - return self._base + def pointer(self) -> PsExpression: + return self._ptr - @base.setter - def base(self, expr: PsExpression): + @pointer.setter + def pointer(self, expr: PsExpression): if not isinstance(expr, PsSymbolExpr) or not isinstance( expr.symbol, PsArrayBasePointer ): @@ -256,17 +312,25 @@ class PsArrayAccess(PsSubscript): ) self._base_ptr = expr.symbol - self._base = expr + self._ptr = expr @property def array(self) -> PsLinearizedArray: return self._base_ptr.array + + @property + def index(self) -> PsExpression: + return self._offset + + @index.setter + def index(self, expr: PsExpression): + self._offset = expr def clone(self) -> PsArrayAccess: - return PsArrayAccess(self._base_ptr, self._index.clone()) + return PsArrayAccess(self._base_ptr, self._offset.clone()) def __repr__(self) -> str: - return f"ArrayAccess({repr(self._base_ptr)}, {repr(self._index)})" + return f"PsArrayAccess({repr(self._base_ptr)}, {repr(self._offset)})" class PsVectorArrayAccess(PsArrayAccess): @@ -314,7 +378,7 @@ class PsVectorArrayAccess(PsArrayAccess): def clone(self) -> PsVectorArrayAccess: return PsVectorArrayAccess( self._base_ptr, - self._index.clone(), + self._offset.clone(), self.vector_entries, self._stride, self._alignment, @@ -525,11 +589,16 @@ class PsNeg(PsUnOp, PsNumericOpTrait): return operator.neg -class PsDeref(PsLvalue, PsUnOp): - pass - - class PsAddressOf(PsUnOp): + """Take the address of a memory location. + + .. DANGER:: + Taking the address of a memory location owned by a symbol or field array + introduces an alias to that memory location. + As pystencils assumes its symbols and fields to never be aliased, this can + subtly change the semantics of a kernel. + Use the address-of operator with utmost care. + """ pass @@ -740,32 +809,44 @@ class PsLt(PsRel): class PsArrayInitList(PsExpression): + """N-dimensional array initialization matrix.""" + __match_args__ = ("items",) - def __init__(self, items: Sequence[PsExpression]): + def __init__(self, items: Sequence[PsExpression | Sequence[PsExpression | Sequence[PsExpression]]]): super().__init__() - self._items = list(items) + self._items = np.array(items, dtype=np.object_) @property - def items(self) -> list[PsExpression]: + def items_grid(self) -> NDArray[np.object_]: return self._items + + @property + def shape(self) -> tuple[int, ...]: + return self._items.shape + + @property + def items(self) -> tuple[PsExpression, ...]: + return tuple(self._items.flat) # type: ignore def get_children(self) -> tuple[PsAstNode, ...]: - return tuple(self._items) + return tuple(self._items.flat) # type: ignore def set_child(self, idx: int, c: PsAstNode): - self._items[idx] = failing_cast(PsExpression, c) + self._items.flat[idx] = failing_cast(PsExpression, c) def clone(self) -> PsExpression: - return PsArrayInitList([expr.clone() for expr in self._items]) + return PsArrayInitList( + np.array([expr.clone() for expr in self.children]).reshape( # type: ignore + self._items.shape + ) + ) def __repr__(self) -> str: return f"PsArrayInitList({repr(self._items)})" -def evaluate_expression( - expr: PsExpression, valuation: dict[str, Any] -) -> Any: +def evaluate_expression(expr: PsExpression, valuation: dict[str, Any]) -> Any: """Evaluate a pystencils backend expression tree with values assigned to symbols according to the given valuation. Only a subset of expression nodes can be processed by this evaluator. diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py index 0d3b78629..72aff0a01 100644 --- a/src/pystencils/backend/ast/util.py +++ b/src/pystencils/backend/ast/util.py @@ -1,8 +1,15 @@ from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, cast + +from ..exceptions import PsInternalCompilerError +from ..symbols import PsSymbol +from ..arrays import PsLinearizedArray +from ...types import PsDereferencableType + if TYPE_CHECKING: from .astnode import PsAstNode + from .expressions import PsExpression def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any: @@ -36,3 +43,43 @@ class AstEqWrapper: # TODO: consider replacing this with smth. more performant # TODO: Check that repr is implemented by all AST nodes return hash(repr(self._node)) + + +def determine_memory_object( + expr: PsExpression, +) -> tuple[PsSymbol | PsLinearizedArray | None, bool]: + """Return the memory object accessed by the given expression, together with its constness + + Returns: + Tuple ``(mem_obj, const)`` identifying the memory object accessed by the given expression, + as well as its constness + """ + from pystencils.backend.ast.expressions import ( + PsSubscript, + PsLookup, + PsSymbolExpr, + PsMemAcc, + PsArrayAccess, + ) + + while isinstance(expr, (PsSubscript, PsLookup)): + match expr: + case PsSubscript(arr, _): + expr = arr + case PsLookup(record, _): + expr = record + + match expr: + case PsSymbolExpr(symb): + return symb, symb.get_dtype().const + case PsMemAcc(ptr, _): + return None, cast(PsDereferencableType, ptr.get_dtype()).base_type.const + case PsArrayAccess(ptr, _): + return ( + expr.array, + cast(PsDereferencableType, ptr.get_dtype()).base_type.const, + ) + case _: + raise PsInternalCompilerError( + "The given expression is a transient and does not refer to a memory object" + ) diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 8928cc689..579d47648 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -16,6 +16,7 @@ from .ast.structural import ( ) from .ast.expressions import ( + PsExpression, PsAdd, PsAddressOf, PsArrayInitList, @@ -26,7 +27,7 @@ from .ast.expressions import ( PsCall, PsCast, PsConstantExpr, - PsDeref, + PsMemAcc, PsDiv, PsRem, PsIntDiv, @@ -36,7 +37,6 @@ from .ast.expressions import ( PsNeg, PsRightShift, PsSub, - PsSubscript, PsSymbolExpr, PsLiteralExpr, PsVectorArrayAccess, @@ -50,6 +50,7 @@ from .ast.expressions import ( PsLt, PsGe, PsLe, + PsSubscript ) from .extensions.foreign_ast import PsForeignExpression @@ -270,16 +271,27 @@ class CAstPrinter: case PsVectorArrayAccess(): raise EmissionError("Cannot print vectorized array accesses") - case PsSubscript(base, index): + case PsMemAcc(base, offset): pc.push_op(Ops.Subscript, LR.Left) base_code = self.visit(base, pc) pc.pop_op() pc.push_op(Ops.Weakest, LR.Middle) - index_code = self.visit(index, pc) + index_code = self.visit(offset, pc) pc.pop_op() return pc.parenthesize(f"{base_code}[{index_code}]", Ops.Subscript) + + case PsSubscript(base, indices): + pc.push_op(Ops.Subscript, LR.Left) + base_code = self.visit(base, pc) + pc.pop_op() + + pc.push_op(Ops.Weakest, LR.Middle) + indices_code = "".join("[" + self.visit(idx, pc) + "]" for idx in indices) + pc.pop_op() + + return pc.parenthesize(base_code + indices_code, Ops.Subscript) case PsLookup(aggr, member_name): pc.push_op(Ops.Lookup, LR.Left) @@ -320,12 +332,12 @@ class CAstPrinter: return pc.parenthesize(f"!{operand_code}", Ops.Not) - case PsDeref(operand): - pc.push_op(Ops.Deref, LR.Right) - operand_code = self.visit(operand, pc) - pc.pop_op() + # case PsDeref(operand): + # pc.push_op(Ops.Deref, LR.Right) + # operand_code = self.visit(operand, pc) + # pc.pop_op() - return pc.parenthesize(f"*{operand_code}", Ops.Deref) + # return pc.parenthesize(f"*{operand_code}", Ops.Deref) case PsAddressOf(operand): pc.push_op(Ops.AddressOf, LR.Right) @@ -355,11 +367,19 @@ class CAstPrinter: f"{cond_code} ? {then_code} : {else_code}", Ops.Ternary ) - case PsArrayInitList(items): + case PsArrayInitList(_): + def print_arr(item) -> str: + if isinstance(item, PsExpression): + return self.visit(item, pc) + else: + # it's a subarray + entries = ", ".join(print_arr(i) for i in item) + return "{ " + entries + " }" + pc.push_op(Ops.Weakest, LR.Middle) - items_str = ", ".join(self.visit(item, pc) for item in items) + arr_str = print_arr(node.items_grid) pc.pop_op() - return "{ " + items_str + " }" + return arr_str case PsForeignExpression(children): pc.push_op(Ops.Weakest, LR.Middle) @@ -379,10 +399,11 @@ class CAstPrinter: def _symbol_decl(self, symb: PsSymbol): dtype = symb.get_dtype() - array_dims = [] - while isinstance(dtype, PsArrayType): - array_dims.append(dtype.length) + if isinstance(dtype, PsArrayType): + array_dims = dtype.shape dtype = dtype.base_type + else: + array_dims = () code = f"{dtype.c_string()} {symb.name}" for d in array_dims: diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index a0328a123..0dc60b1b1 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -12,7 +12,7 @@ from ..symbols import PsSymbol from ..constants import PsConstant from .context import KernelCreationContext -from .freeze import FreezeExpressions +from .freeze import FreezeExpressions, ExprLike from .typification import Typifier from .iteration_space import FullIterationSpace @@ -42,14 +42,14 @@ class AstFactory: pass @overload - def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression: + def parse_sympy(self, sp_obj: ExprLike) -> PsExpression: pass @overload def parse_sympy(self, sp_obj: AssignmentBase) -> PsAssignment: pass - def parse_sympy(self, sp_obj: sp.Expr | AssignmentBase) -> PsAstNode: + def parse_sympy(self, sp_obj: ExprLike | AssignmentBase) -> PsAstNode: """Parse a SymPy expression or assignment through `FreezeExpressions` and `Typifier`. The expression or assignment will be typified in a numerical context, using the kernel diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index ff7754ac2..0ae5a0d1b 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -14,7 +14,7 @@ from ...sympyextensions import ( ConditionalFieldAccess, ) from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType -from ...sympyextensions.pointers import AddressOf +from ...sympyextensions.pointers import AddressOf, mem_acc from ...field import Field, FieldType from .context import KernelCreationContext @@ -55,6 +55,7 @@ from ..ast.expressions import ( PsAnd, PsOr, PsNot, + PsMemAcc ) from ..constants import PsConstant @@ -275,16 +276,36 @@ class FreezeExpressions: return PsSymbolExpr(symb) def map_Tuple(self, expr: sp.Tuple) -> PsArrayInitList: + if not expr: + raise FreezeError("Cannot translate an empty tuple.") + items = [self.visit_expr(item) for item in expr] - return PsArrayInitList(items) + + if any(isinstance(i, PsArrayInitList) for i in items): + # base case: have nested arrays + if not all(isinstance(i, PsArrayInitList) for i in items): + raise FreezeError( + f"Cannot translate nested arrays of non-uniform shape: {expr}" + ) + + subarrays = cast(list[PsArrayInitList], items) + shape_tail = subarrays[0].shape + + if not all(s.shape == shape_tail for s in subarrays[1:]): + raise FreezeError( + f"Cannot translate nested arrays of non-uniform shape: {expr}" + ) + + return PsArrayInitList([s.items_grid for s in subarrays]) # type: ignore + else: + # base case: no nested arrays + return PsArrayInitList(items) def map_Indexed(self, expr: sp.Indexed) -> PsSubscript: assert isinstance(expr.base, sp.IndexedBase) base = self.visit_expr(expr.base.label) - subscript = PsSubscript(base, self.visit_expr(expr.indices[0])) - for idx in expr.indices[1:]: - subscript = PsSubscript(subscript, self.visit_expr(idx)) - return subscript + indices = [self.visit_expr(i) for i in expr.indices] + return PsSubscript(base, indices) def map_Access(self, access: Field.Access): field = access.field @@ -429,6 +450,8 @@ class FreezeExpressions: ) case AddressOf(): return PsAddressOf(*args) + case mem_acc(): + return PsMemAcc(*args) case _: raise FreezeError(f"Unsupported function: {func}") diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index ce2d24f98..819d4a12b 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TypeVar +from typing import TypeVar, Callable from .context import KernelCreationContext from ...types import ( @@ -35,11 +35,11 @@ from ..ast.expressions import ( PsCall, PsTernary, PsCast, - PsDeref, PsAddressOf, PsConstantExpr, PsLookup, PsSubscript, + PsMemAcc, PsSymbolExpr, PsLiteralExpr, PsRel, @@ -47,6 +47,7 @@ from ..ast.expressions import ( PsNot, ) from ..functions import PsMathFunction, CFunction +from ..ast.util import determine_memory_object __all__ = ["Typifier"] @@ -57,38 +58,40 @@ class TypificationError(Exception): NodeT = TypeVar("NodeT", bound=PsAstNode) +ResolutionHook = Callable[[PsType], None] + class TypeContext: """Typing context, with support for type inference and checking. Instances of this class are used to propagate and check data types across expression subtrees - of the AST. Each type context has: - - - A target type `target_type`, which shall be applied to all expressions it covers - - A set of restrictions on the target type: - - `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides - - Additional restrictions may be added in the future. + of the AST. Each type context has a target type `target_type`, which shall be applied to all expressions it covers """ def __init__( self, target_type: PsType | None = None, - require_nonconst: bool = False, ): - self._require_nonconst = require_nonconst self._deferred_exprs: list[PsExpression] = [] - self._target_type = ( - self._fix_constness(target_type) if target_type is not None else None - ) + self._target_type = deconstify(target_type) if target_type is not None else None + + self._hooks: list[ResolutionHook] = [] @property def target_type(self) -> PsType | None: return self._target_type - @property - def require_nonconst(self) -> bool: - return self._require_nonconst + def add_hook(self, hook: ResolutionHook): + """Adds a resolution hook to this context. + + The hook will be called with the context's target type as soon as it becomes known, + which might be immediately. + """ + if self._target_type is None: + self._hooks.append(hook) + else: + hook(self._target_type) def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None): """Applies the given ``dtype`` to this type context, and optionally to the given expression. @@ -102,7 +105,7 @@ class TypeContext: and will be replaced by it. """ - dtype = self._fix_constness(dtype) + dtype = deconstify(dtype) if self._target_type is not None and dtype != self._target_type: raise TypificationError( @@ -134,6 +137,12 @@ class TypeContext: self._apply_target_type(expr) def _propagate_target_type(self): + assert self._target_type is not None + + for hook in self._hooks: + hook(self._target_type) + self._hooks = [] + for expr in self._deferred_exprs: self._apply_target_type(expr) self._deferred_exprs = [] @@ -211,30 +220,10 @@ class TypeContext: def _compatible(self, dtype: PsType): """Checks whether the given data type is compatible with the context's target type. - - If the target type is ``const``, they must be equal up to const qualification; - if the target type is not ``const``, `dtype` must match it exactly. + The two must match except for constness. """ assert self._target_type is not None - if self._target_type.const: - return constify(dtype) == self._target_type - else: - return dtype == self._target_type - - def _fix_constness(self, dtype: PsType, expr: PsExpression | None = None): - if self._require_nonconst: - if dtype.const: - if expr is None: - raise TypificationError( - f"Type mismatch: Encountered {dtype} in non-constant context." - ) - else: - raise TypificationError( - f"Type mismatch at expression {expr}: Encountered {dtype} in non-constant context." - ) - return dtype - else: - return constify(dtype) + return deconstify(dtype) == self._target_type class Typifier: @@ -276,11 +265,7 @@ class Typifier: **Typing of symbol expressions** - Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but - not necessarily their const-qualification. - A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type, - and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`, - but not vice versa. + Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types. """ def __init__(self, ctx: KernelCreationContext): @@ -316,7 +301,44 @@ class Typifier: for s in statements: self.visit(s) - case PsDeclaration(lhs, rhs): + case PsDeclaration(lhs, rhs) if isinstance(rhs, PsArrayInitList): + # Special treatment for array declarations + assert isinstance(lhs, PsSymbolExpr) + + decl_tc = TypeContext() + items_tc = TypeContext() + + if (lhs_type := lhs.symbol.dtype) is not None: + if not isinstance(lhs_type, PsArrayType): + raise TypificationError( + f"Illegal LHS type in array declaration: {lhs_type}" + ) + + if lhs_type.shape != rhs.shape: + raise TypificationError( + f"Incompatible shapes in declaration of array symbol {lhs.symbol}.\n" + f" Symbol shape: {lhs_type.shape}\n" + f" Array shape: {rhs.shape}" + ) + + items_tc.apply_dtype(lhs_type.base_type) + decl_tc.apply_dtype(lhs_type, lhs) + else: + decl_tc.infer_dtype(lhs) + + for item in rhs.items: + self.visit_expr(item, items_tc) + + if items_tc.target_type is None: + items_tc.apply_dtype(self._ctx.default_dtype) + + if decl_tc.target_type is None: + assert items_tc.target_type is not None + decl_tc.apply_dtype( + PsArrayType(items_tc.target_type, rhs.shape), rhs + ) + + case PsDeclaration(lhs, rhs) | PsAssignment(lhs, rhs): # Only if the LHS is an untyped symbol, infer its type from the RHS infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None @@ -333,27 +355,11 @@ class Typifier: if infer_lhs and tc.target_type is None: # no type has been inferred -> use the default dtype tc.apply_dtype(self._ctx.default_dtype) - - case PsAssignment(lhs, rhs): - infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None - - tc_lhs = TypeContext(require_nonconst=True) - - if infer_lhs: - tc_lhs.infer_dtype(lhs) - else: - self.visit_expr(lhs, tc_lhs) - assert tc_lhs.target_type is not None - - tc_rhs = TypeContext(target_type=tc_lhs.target_type) - self.visit_expr(rhs, tc_rhs) - - if infer_lhs: - if tc_rhs.target_type is None: - tc_rhs.apply_dtype(self._ctx.default_dtype) - - assert tc_rhs.target_type is not None - tc_lhs.apply_dtype(deconstify(tc_rhs.target_type)) + elif not isinstance(node, PsDeclaration): + # check mutability of LHS + _, lhs_const = determine_memory_object(lhs) + if lhs_const: + raise TypificationError(f"Cannot assign to immutable LHS {lhs}") case PsConditional(cond, branch_true, branch_false): cond_tc = TypeContext(PsBoolType()) @@ -409,49 +415,56 @@ class Typifier: case PsArrayAccess(bptr, idx): tc.apply_dtype(bptr.array.element_type, expr) + self._handle_idx(idx) + + case PsMemAcc(ptr, offset): + ptr_tc = TypeContext() + self.visit_expr(ptr, ptr_tc) - index_tc = TypeContext() - self.visit_expr(idx, index_tc) - if index_tc.target_type is None: - index_tc.apply_dtype(self._ctx.index_dtype, idx) - elif not isinstance(index_tc.target_type, PsIntegerType): + if not isinstance(ptr_tc.target_type, PsPointerType): raise TypificationError( - f"Array index is not of integer type: {idx} has type {index_tc.target_type}" + f"Type of pointer argument to memory access was not a pointer type: {ptr_tc.target_type}" ) - case PsSubscript(arr, idx): - arr_tc = TypeContext() - self.visit_expr(arr, arr_tc) + tc.apply_dtype(ptr_tc.target_type.base_type, expr) + self._handle_idx(offset) - if not isinstance(arr_tc.target_type, PsDereferencableType): - raise TypificationError( - "Type of subscript base is not subscriptable." - ) + case PsSubscript(arr, indices): + if isinstance(arr, PsArrayInitList): + shape = arr.shape - tc.apply_dtype(arr_tc.target_type.base_type, expr) + # extend outer context over the init-list entries + for item in arr.items: + self.visit_expr(item, tc) - index_tc = TypeContext() - self.visit_expr(idx, index_tc) - if index_tc.target_type is None: - index_tc.apply_dtype(self._ctx.index_dtype, idx) - elif not isinstance(index_tc.target_type, PsIntegerType): - raise TypificationError( - f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}" - ) + # learn the array type from the items + def arr_hook(element_type: PsType): + arr.dtype = PsArrayType(element_type, arr.shape) - case PsDeref(ptr): - ptr_tc = TypeContext() - self.visit_expr(ptr, ptr_tc) + tc.add_hook(arr_hook) + else: + # type of array has to be known + arr_tc = TypeContext() + self.visit_expr(arr, arr_tc) - if not isinstance(ptr_tc.target_type, PsDereferencableType): + if not isinstance(arr_tc.target_type, PsArrayType): + raise TypificationError( + f"Type of array argument to subscript was not an array type: {arr_tc.target_type}" + ) + + tc.apply_dtype(arr_tc.target_type.base_type, expr) + shape = arr_tc.target_type.shape + + if len(indices) != len(shape): raise TypificationError( - "Type of argument to a Deref is not dereferencable" + f"Invalid number of indices to {len(shape)}-dimensional array: {len(indices)}" ) - tc.apply_dtype(ptr_tc.target_type.base_type, expr) + for idx in indices: + self._handle_idx(idx) case PsAddressOf(arg): - if not isinstance(arg, (PsSymbolExpr, PsSubscript, PsDeref, PsLookup)): + if not isinstance(arg, (PsSymbolExpr, PsSubscript, PsMemAcc, PsLookup)): raise TypificationError( f"Illegal expression below AddressOf operator: {arg}" ) @@ -468,8 +481,8 @@ class Typifier: match arg: case PsSymbolExpr(s): pointed_to_type = s.get_dtype() - case PsSubscript(arr, _) | PsDeref(arr): - arr_type = arr.get_dtype() + case PsSubscript(ptr, _) | PsMemAcc(ptr, _): + arr_type = ptr.get_dtype() assert isinstance(arr_type, PsDereferencableType) pointed_to_type = arr_type.base_type case PsLookup(aggr, member_name): @@ -491,7 +504,7 @@ class Typifier: case PsLookup(aggr, member_name): # Members of a struct type inherit the struct type's `const` qualifier - aggr_tc = TypeContext(None, require_nonconst=tc.require_nonconst) + aggr_tc = TypeContext() self.visit_expr(aggr, aggr_tc) aggr_type = aggr_tc.target_type @@ -566,32 +579,11 @@ class Typifier: f"Don't know how to typify calls to {function}" ) - case PsArrayInitList(items): - items_tc = TypeContext() - for item in items: - self.visit_expr(item, items_tc) - - if items_tc.target_type is None: - if tc.target_type is None: - raise TypificationError(f"Unable to infer type of array {expr}") - elif not isinstance(tc.target_type, PsArrayType): - raise TypificationError( - f"Cannot apply type {tc.target_type} to an array initializer." - ) - elif ( - tc.target_type.length is not None - and tc.target_type.length != len(items) - ): - raise TypificationError( - "Array size mismatch: Cannot typify initializer list with " - f"{len(items)} items as {tc.target_type}" - ) - else: - items_tc.apply_dtype(tc.target_type.base_type) - tc.infer_dtype(expr) - else: - arr_type = PsArrayType(items_tc.target_type, len(items)) - tc.apply_dtype(arr_type, expr) + case PsArrayInitList(_): + raise TypificationError( + "Unable to typify array initializer in isolation.\n" + f" Array: {expr}" + ) case PsCast(dtype, arg): arg_tc = TypeContext() @@ -606,3 +598,16 @@ class Typifier: case _: raise NotImplementedError(f"Can't typify {expr}") + + def _handle_idx(self, idx: PsExpression): + index_tc = TypeContext() + self.visit_expr(idx, index_tc) + + if index_tc.target_type is None: + index_tc.apply_dtype(self._ctx.index_dtype, idx) + elif not isinstance(index_tc.target_type, PsIntegerType): + raise TypificationError( + f"Invalid data type in index expression.\n" + f" Expression: {idx}\n" + f" Type: {index_tc.target_type}" + ) diff --git a/src/pystencils/backend/literals.py b/src/pystencils/backend/literals.py index dc254da0e..976e6b203 100644 --- a/src/pystencils/backend/literals.py +++ b/src/pystencils/backend/literals.py @@ -6,7 +6,7 @@ class PsLiteral: """Representation of literal code. Instances of this class represent code literals inside the AST. - These literals are not to be confused with C literals; the name `Literal` refers to the fact that + These literals are not to be confused with C literals; the name 'Literal' refers to the fact that the code generator takes them "literally", printing them as they are. Each literal has to be annotated with a type, and is considered constant within the scope of a kernel. diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index 39ec09992..fa42ed021 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -121,7 +121,7 @@ class SyclPlatform(GenericGpu): for i, dim in enumerate(dimensions): # Slowest to fastest coord = PsExpression.make(PsConstant(i, self._ctx.index_dtype)) - work_item_idx = PsSubscript(id_symbol, coord) + work_item_idx = PsSubscript(id_symbol, (coord,)) dim.counter.dtype = constify(dim.counter.get_dtype()) work_item_idx.dtype = dim.counter.get_dtype() @@ -151,7 +151,7 @@ class SyclPlatform(GenericGpu): id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type)) zero = PsExpression.make(PsConstant(0, self._ctx.index_dtype)) - subscript = PsSubscript(id_symbol, zero) + subscript = PsSubscript(id_symbol, (zero,)) ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) subscript.dtype = ispace.sparse_counter.get_dtype() diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index ccaf9fbe9..0c6f6883d 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -7,7 +7,7 @@ from ..ast.expressions import ( PsExpression, PsVectorArrayAccess, PsAddressOf, - PsSubscript, + PsMemAcc, ) from ..transformations.select_intrinsics import IntrinsicOps from ...types import PsCustomType, PsVectorType, PsPointerType @@ -145,7 +145,7 @@ class X86VectorCpu(GenericVectorCpu): if acc.stride == 1: load_func = _x86_packed_load(self._vector_arch, acc.dtype, False) return load_func( - PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)) + PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index)) ) else: raise NotImplementedError("Gather loads not implemented yet.") @@ -154,7 +154,7 @@ class X86VectorCpu(GenericVectorCpu): if acc.stride == 1: store_func = _x86_packed_store(self._vector_arch, acc.dtype, False) return store_func( - PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)), + PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index)), arg, ) else: diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py index 03d79a689..7404abd94 100644 --- a/src/pystencils/backend/transformations/erase_anonymous_structs.py +++ b/src/pystencils/backend/transformations/erase_anonymous_structs.py @@ -8,7 +8,7 @@ from ..ast.expressions import ( PsArrayAccess, PsLookup, PsExpression, - PsDeref, + PsMemAcc, PsAddressOf, PsCast, ) @@ -99,8 +99,9 @@ class EraseAnonymousStructTypes: ) type_erased_access = PsArrayAccess(type_erased_bp, byte_index) - deref = PsDeref( - PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)) + deref = PsMemAcc( + PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)), + PsExpression.make(PsConstant(0)) ) typify = Typifier(self._ctx) diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index 2368868a9..d4dfd3d04 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -9,12 +9,14 @@ from ..ast.expressions import ( PsConstantExpr, PsLiteralExpr, PsCall, - PsDeref, + PsArrayAccess, PsSubscript, + PsLookup, PsUnOp, PsBinOp, PsArrayInitList, ) +from ..ast.util import determine_memory_object from ...types import PsDereferencableType from ..symbols import PsSymbol @@ -48,7 +50,11 @@ class HoistContext: case PsCall(func): return isinstance(func, PsMathFunction) and args_invariant(expr) - case PsSubscript(ptr, _) | PsDeref(ptr): + case PsSubscript() | PsLookup(): + return determine_memory_object(expr)[1] and args_invariant(expr) + + case PsArrayAccess(ptr, _): + # Regular pointer derefs are never invariant, since we cannot reason about aliasing ptr_type = cast(PsDereferencableType, ptr.get_dtype()) return ptr_type.base_type.const and args_invariant(expr) diff --git a/src/pystencils/sympyextensions/pointers.py b/src/pystencils/sympyextensions/pointers.py index c69f9376d..2ebeba7c9 100644 --- a/src/pystencils/sympyextensions/pointers.py +++ b/src/pystencils/sympyextensions/pointers.py @@ -31,3 +31,14 @@ class AddressOf(sp.Function): return PsPointerType(arg_type, restrict=True, const=True) else: raise ValueError(f'pystencils supports only non void pointers. Current address_of type: {self.args[0]}') + + +class mem_acc(sp.Function): + """Memory access through a raw pointer with an offset. + + This function should be used to model offset memory accesses through raw pointers. + """ + + @classmethod + def eval(cls, ptr, offset): + return None diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index 61e3d73fd..e6fc4bb78 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import final, Any, Sequence +from typing import final, Any, Sequence, SupportsIndex from dataclasses import dataclass import numpy as np @@ -98,31 +98,57 @@ class PsPointerType(PsDereferencableType): class PsArrayType(PsDereferencableType): - """C array type of known or unknown size.""" + """Multidimensional array of fixed shape. + + The element type of an array is never const; only the array itself can be. + If ``element_type`` is const, its constness will be removed. + """ def __init__( - self, base_type: PsType, length: int | None = None, const: bool = False + self, element_type: PsType, shape: SupportsIndex | Sequence[SupportsIndex], const: bool = False ): - self._length = length - super().__init__(base_type, const) + from operator import index + if isinstance(shape, SupportsIndex): + shape = (index(shape),) + else: + shape = tuple(index(s) for s in shape) + + if not shape or any(s <= 0 for s in shape): + raise ValueError(f"Invalid array shape: {shape}") + + if isinstance(element_type, PsArrayType): + raise ValueError("Element type of array cannot be another array.") + + element_type = deconstify(element_type) + + self._shape = shape + super().__init__(element_type, const) def __args__(self) -> tuple[Any, ...]: """ - >>> t = PsArrayType(PsBoolType(), 13) + >>> t = PsArrayType(PsBoolType(), (13, 42)) >>> t == PsArrayType(*t.__args__()) True """ - return (self._base_type, self._length) + return (self._base_type, self._shape) @property - def length(self) -> int | None: - return self._length + def shape(self) -> tuple[int, ...]: + """Shape of this array""" + return self._shape + + @property + def dim(self) -> int: + """Dimensionality of this array""" + return len(self._shape) def c_string(self) -> str: - return f"{self._base_type.c_string()} [{str(self._length) if self._length is not None else ''}]" + arr_brackets = "".join(f"[{s}]" for s in self._shape) + const = self._const_string() + return const + self._base_type.c_string() + arr_brackets def __repr__(self) -> str: - return f"PsArrayType(element_type={repr(self._base_type)}, size={self._length}, const={self._const})" + return f"PsArrayType(element_type={repr(self._base_type)}, shape={self._shape}, const={self._const})" class PsStructType(PsType): @@ -131,6 +157,8 @@ class PsStructType(PsType): A struct type is defined by its sequence of members. The struct may optionally have a name, although the code generator currently does not support named structs and treats them the same way as anonymous structs. + + Struct member types cannot be ``const``; if a ``const`` member type is passed, its constness will be removed. """ @dataclass(frozen=True) @@ -138,6 +166,10 @@ class PsStructType(PsType): name: str dtype: PsType + def __post_init__(self): + # Need to use object.__setattr__ because instances are frozen + object.__setattr__(self, "dtype", deconstify(self.dtype)) + @staticmethod def _canonical_members(members: Sequence[PsStructType.Member | tuple[str, PsType]]): return tuple( diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 072882a7b..341b75601 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -10,6 +10,7 @@ from pystencils import ( DynamicType, ) from pystencils.sympyextensions import CastFunc +from pystencils.sympyextensions.pointers import mem_acc from pystencils.backend.ast.structural import ( PsAssignment, @@ -40,6 +41,9 @@ from pystencils.backend.ast.expressions import ( PsAdd, PsMul, PsSub, + PsArrayInitList, + PsSubscript, + PsMemAcc, ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import PsMathFunction, MathFunctions @@ -353,3 +357,100 @@ def test_add_sub(): expr = freeze(x - 2 * y) assert expr.structurally_equal(PsAdd(x2, PsMul(minus_two, y2))) + + +def test_tuple_array_literals(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x, y, z = sp.symbols("x, y, z") + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + + one = PsExpression.make(PsConstant(1)) + three = PsExpression.make(PsConstant(3)) + four = PsExpression.make(PsConstant(4)) + + arr_literal = freeze(sp.Tuple(3 + y, z, z / 4)) + assert arr_literal.structurally_equal( + PsArrayInitList([three + y2, z2, one / four * z2]) + ) + + +def test_nested_tuples(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + def f(n): + return freeze(sp.sympify(n)) + + shape = (2, 3, 2) + symb_arr = sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10))) + arr_literal = freeze(symb_arr) + + assert isinstance(arr_literal, PsArrayInitList) + assert arr_literal.shape == shape + + assert arr_literal.structurally_equal( + PsArrayInitList( + [ + ((f(1), f(2)), (f(3), f(4)), (f(5), f(6))), + ((f(5), f(6)), (f(7), f(8)), (f(9), f(10))), + ] + ) + ) + + +def test_invalid_arrays(): + ctx = KernelCreationContext() + + freeze = FreezeExpressions(ctx) + # invalid: nonuniform nesting depth + symb_arr = sp.Tuple((3, 32), 14) + with pytest.raises(FreezeError): + _ = freeze(symb_arr) + + # invalid: nonuniform sub-array length + symb_arr = sp.Tuple((3, 32), (14, -7, 3)) + with pytest.raises(FreezeError): + _ = freeze(symb_arr) + + # invalid: empty subarray + symb_arr = sp.Tuple((), (0, -9)) + with pytest.raises(FreezeError): + _ = freeze(symb_arr) + + # invalid: all subarrays empty + symb_arr = sp.Tuple((), ()) + with pytest.raises(FreezeError): + _ = freeze(symb_arr) + + +def test_memory_access(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + ptr = sp.Symbol("ptr") + expr = freeze(mem_acc(ptr, 31)) + + assert isinstance(expr, PsMemAcc) + assert expr.pointer.structurally_equal(PsExpression.make(ctx.get_symbol("ptr"))) + assert expr.offset.structurally_equal(PsExpression.make(PsConstant(31))) + + +def test_indexed(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x, y, z = sp.symbols("x, y, z") + a = sp.IndexedBase("a") + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + a2 = PsExpression.make(ctx.get_symbol("a")) + + expr = freeze(a[x, y, z]) + assert expr.structurally_equal(PsSubscript(a2, (x2, y2, z2))) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 4c6a4d602..5ea2aa15e 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -5,6 +5,7 @@ import numpy as np from typing import cast from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment +from pystencils.sympyextensions.pointers import mem_acc from pystencils.backend.ast.structural import ( PsDeclaration, @@ -32,11 +33,12 @@ from pystencils.backend.ast.expressions import ( PsLt, PsCall, PsTernary, + PsMemAcc ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import CFunction from pystencils.types import constify, create_type, create_numeric_type -from pystencils.types.quick import Fp, Int, Bool, Arr +from pystencils.types.quick import Fp, Int, Bool, Arr, Ptr from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -64,7 +66,7 @@ def test_typify_simple(): assert isinstance(fasm, PsDeclaration) def check(expr): - assert expr.dtype == constify(ctx.default_dtype) + assert expr.dtype == ctx.default_dtype match expr: case PsConstantExpr(cs): assert cs.value == 2 @@ -82,41 +84,6 @@ def test_typify_simple(): check(fasm.rhs) -def test_rhs_constness(): - default_type = Fp(32) - ctx = KernelCreationContext(default_dtype=default_type) - - freeze = FreezeExpressions(ctx) - typify = Typifier(ctx) - - f = Field.create_generic( - "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM - ) - f_const = Field.create_generic( - "f_const", - 1, - index_shape=(1,), - dtype=constify(default_type), - field_type=FieldType.CUSTOM, - ) - - x, y, z = sp.symbols("x, y, z") - - # Right-hand sides should always get const types - asm = typify(freeze(Assignment(x, f.absolute_access([0], [0])))) - assert asm.rhs.get_dtype().const - - asm = typify( - freeze( - Assignment( - f.absolute_access([0], [0]), - f.absolute_access([0], [0]) * f_const.absolute_access([0], [0]) * x + y, - ) - ) - ) - assert asm.rhs.get_dtype().const - - def test_lhs_constness(): default_type = Fp(32) ctx = KernelCreationContext(default_dtype=default_type) @@ -136,7 +103,7 @@ def test_lhs_constness(): x, y, z = sp.symbols("x, y, z") - # Assignment RHS may not be const + # Can assign to non-const LHS asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y))) assert not asm.lhs.get_dtype().const @@ -158,7 +125,7 @@ def test_lhs_constness(): q = ctx.get_symbol("q", Fp(32, const=True)) ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q)) ast = typify(ast) - assert ast.lhs.dtype == Fp(32, const=True) + assert ast.lhs.dtype == Fp(32) ast = PsAssignment(PsExpression.make(q), PsExpression.make(q)) with pytest.raises(TypificationError): @@ -200,7 +167,7 @@ def test_default_typing(): expr = typify(expr) def check(expr): - assert expr.dtype == constify(ctx.default_dtype) + assert expr.dtype == ctx.default_dtype match expr: case PsConstantExpr(cs): assert cs.value in (2, 3, -4) @@ -217,6 +184,141 @@ def test_default_typing(): check(expr) +def test_inline_arrays_1d(): + ctx = KernelCreationContext(default_dtype=Fp(32)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x = sp.Symbol("x") + y = TypedSymbol("y", Fp(16)) + idx = TypedSymbol("idx", Int(32)) + + arr: PsArrayInitList = cast(PsArrayInitList, freeze(sp.Tuple(1, 2, 3, 4))) + decl = PsDeclaration(freeze(x), freeze(y) + PsSubscript(arr, (freeze(idx),))) + # The array elements should learn their type from the context, which gets it from `y` + + decl = typify(decl) + assert decl.lhs.dtype == Fp(16) + assert decl.rhs.dtype == Fp(16) + + assert arr.dtype == Arr(Fp(16), (4,)) + for item in arr.items: + assert item.dtype == Fp(16) + + +def test_inline_arrays_3d(): + ctx = KernelCreationContext(default_dtype=Fp(32)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x = sp.Symbol("x") + y = TypedSymbol("y", Fp(16)) + idx = [TypedSymbol(f"idx_{i}", Int(32)) for i in range(3)] + + arr: PsArrayInitList = freeze( + sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10))) + ) + decl = PsDeclaration( + freeze(x), + freeze(y) + PsSubscript(arr, (freeze(idx[0]), freeze(idx[1]), freeze(idx[2]))), + ) + # The array elements should learn their type from the context, which gets it from `y` + + decl = typify(decl) + assert decl.lhs.dtype == Fp(16) + assert decl.rhs.dtype == Fp(16) + + assert arr.dtype == Arr(Fp(16), (2, 3, 2)) + assert arr.shape == (2, 3, 2) + for item in arr.items: + assert item.dtype == Fp(16) + + +def test_array_subscript(): + ctx = KernelCreationContext(default_dtype=Fp(16)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + arr = sp.IndexedBase(TypedSymbol("arr", Arr(Fp(32), (16,)))) + expr = freeze(arr[3]) + expr = typify(expr) + + assert expr.dtype == Fp(32) + + arr = sp.IndexedBase(TypedSymbol("arr2", Arr(Fp(32), (7, 31)))) + expr = freeze(arr[3, 5]) + expr = typify(expr) + + assert expr.dtype == Fp(32) + + +def test_invalid_subscript(): + ctx = KernelCreationContext(default_dtype=Fp(16)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + non_arr = sp.IndexedBase(TypedSymbol("non_arr", Int(64))) + expr = freeze(non_arr[3]) + + with pytest.raises(TypificationError): + expr = typify(expr) + + wrong_shape_arr = sp.IndexedBase( + TypedSymbol("wrong_shape_arr", Arr(Fp(32), (7, 31, 5))) + ) + expr = freeze(wrong_shape_arr[3, 5]) + + with pytest.raises(TypificationError): + expr = typify(expr) + + # raw pointers are not arrays, cannot enter subscript + ptr = sp.IndexedBase( + TypedSymbol("ptr", Ptr(Int(16))) + ) + expr = freeze(ptr[37]) + + with pytest.raises(TypificationError): + expr = typify(expr) + + +def test_mem_acc(): + ctx = KernelCreationContext(default_dtype=Fp(16)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + ptr = TypedSymbol("ptr", Ptr(Int(64))) + idx = TypedSymbol("idx", Int(32)) + + expr = freeze(mem_acc(ptr, idx)) + expr = typify(expr) + + assert isinstance(expr, PsMemAcc) + assert expr.dtype == Int(64) + assert expr.offset.dtype == Int(32) + + +def test_invalid_mem_acc(): + ctx = KernelCreationContext(default_dtype=Fp(16)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + non_ptr = TypedSymbol("non_ptr", Int(64)) + idx = TypedSymbol("idx", Int(32)) + + expr = freeze(mem_acc(non_ptr, idx)) + + with pytest.raises(TypificationError): + _ = typify(expr) + + arr = TypedSymbol("arr", Arr(Int(64), (31,))) + idx = TypedSymbol("idx", Int(32)) + + expr = freeze(mem_acc(arr, idx)) + + with pytest.raises(TypificationError): + _ = typify(expr) + + def test_lhs_inference(): ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) freeze = FreezeExpressions(ctx) @@ -232,13 +334,13 @@ def test_lhs_inference(): fasm = typify(freeze(asm)) assert ctx.get_symbol("x").dtype == Fp(32) - assert fasm.lhs.dtype == constify(Fp(32)) + assert fasm.lhs.dtype == Fp(32) asm = Assignment(y, 3 - w) fasm = typify(freeze(asm)) assert ctx.get_symbol("y").dtype == Fp(16) - assert fasm.lhs.dtype == constify(Fp(16)) + assert fasm.lhs.dtype == Fp(16) fasm = PsAssignment(PsExpression.make(ctx.get_symbol("z")), freeze(3 - w)) fasm = typify(fasm) @@ -252,8 +354,38 @@ def test_lhs_inference(): fasm = typify(fasm) assert ctx.get_symbol("r").dtype == Bool() - assert fasm.lhs.dtype == constify(Bool()) - assert fasm.rhs.dtype == constify(Bool()) + assert fasm.lhs.dtype == Bool() + assert fasm.rhs.dtype == Bool() + + +def test_array_declarations(): + ctx = KernelCreationContext(default_dtype=Fp(32)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x, y, z = sp.symbols("x, y, z") + + # Array type fallback to default + arr1 = sp.Symbol("arr1") + decl = freeze(Assignment(arr1, sp.Tuple(1, 2, 3, 4))) + decl = typify(decl) + + assert ctx.get_symbol("arr1").dtype == Arr(Fp(32), (4,)) + + # Array type determined by default-typed symbol + arr2 = sp.Symbol("arr2") + decl = freeze(Assignment(arr2, sp.Tuple((x, y, -7), (3, -2, 51)))) + decl = typify(decl) + + assert ctx.get_symbol("arr2").dtype == Arr(Fp(32), (2, 3)) + + # Array type determined by pre-typed symbol + q = TypedSymbol("q", Fp(16)) + arr3 = sp.Symbol("arr3") + decl = freeze(Assignment(arr3, sp.Tuple((q, 2), (-q, 0.123)))) + decl = typify(decl) + + assert ctx.get_symbol("arr3").dtype == Arr(Fp(16), (2, 2)) def test_erronous_typing(): @@ -292,17 +424,17 @@ def test_invalid_indices(): ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) typify = Typifier(ctx) - arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64)))) + arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64), (61,)))) x, y, z = [PsExpression.make(ctx.get_symbol(x)) for x in "xyz"] # Using default-typed symbols as array indices is illegal when the default type is a float - fasm = PsAssignment(PsSubscript(arr, x + y), z) + fasm = PsAssignment(PsSubscript(arr, (x + y,)), z) with pytest.raises(TypificationError): typify(fasm) - fasm = PsAssignment(z, PsSubscript(arr, x + y)) + fasm = PsAssignment(z, PsSubscript(arr, (x + y,))) with pytest.raises(TypificationError): typify(fasm) @@ -394,7 +526,7 @@ def test_typify_bools_and_relations(): expr = PsAnd(PsEq(x, y), PsAnd(true, PsNot(PsOr(p, q)))) expr = typify(expr) - assert expr.dtype == Bool(const=True) + assert expr.dtype == Bool() def test_bool_in_numerical_context(): @@ -418,7 +550,7 @@ def test_typify_conditionals(rel): cond = PsConditional(rel(x, y), PsBlock([])) cond = typify(cond) - assert cond.condition.dtype == Bool(const=True) + assert cond.condition.dtype == Bool() def test_invalid_conditions(): @@ -447,11 +579,11 @@ def test_typify_ternary(): expr = PsTernary(p, x, y) expr = typify(expr) - assert expr.dtype == Fp(32, const=True) + assert expr.dtype == Fp(32) expr = PsTernary(PsAnd(p, q), a, b + a) expr = typify(expr) - assert expr.dtype == Int(32, const=True) + assert expr.dtype == Int(32) expr = PsTernary(PsAnd(p, q), a, x) with pytest.raises(TypificationError): @@ -475,9 +607,9 @@ def test_cfunction(): result = typify(PsCall(threeway, [x, y])) - assert result.get_dtype() == Int(32, const=True) - assert result.args[0].get_dtype() == Fp(32, const=True) - assert result.args[1].get_dtype() == Fp(32, const=True) + assert result.get_dtype() == Int(32) + assert result.args[0].get_dtype() == Fp(32) + assert result.args[1].get_dtype() == Fp(32) with pytest.raises(TypificationError): _ = typify(PsCall(threeway, (x, p))) diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py index 09c63a557..02f03bfa9 100644 --- a/tests/nbackend/test_ast.py +++ b/tests/nbackend/test_ast.py @@ -3,7 +3,8 @@ from pystencils.backend.constants import PsConstant from pystencils.backend.ast.expressions import ( PsExpression, PsCast, - PsDeref, + PsMemAcc, + PsArrayInitList, PsSubscript, ) from pystencils.backend.ast.structural import ( @@ -45,6 +46,10 @@ def test_cloning(): PsConditional( y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) ), + PsArrayInitList([ + [x, y, one + x], + [one, c2, z] + ]), PsPragma("omp parallel for"), PsLoop( x, @@ -58,8 +63,8 @@ def test_cloning(): PsAssignment(x, y), PsPragma("#pragma clang loop vectorize(enable)"), PsStatement( - PsDeref(PsCast(Ptr(Fp(32)), z)) - + PsSubscript(z, one + one + one) + PsMemAcc(PsCast(Ptr(Fp(32)), z), one) + + PsSubscript(z, (one + one + one, y + one)) ), ] ), diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index 8fb44e748..dc7a86b0b 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -55,14 +55,15 @@ def test_printing_integer_functions(): PsBitwiseOr, PsBitwiseXor, PsIntDiv, - PsRem + PsRem, ) expr = PsBitwiseAnd( PsBitwiseXor( PsBitwiseXor(j, k), PsBitwiseOr(PsLeftShift(i, PsRightShift(j, k)), PsIntDiv(i, k)), - ) + PsRem(i, k), + ) + + PsRem(i, k), i, ) code = cprint(expr) @@ -154,3 +155,32 @@ def test_ternary(): expr = PsTernary(PsTernary(p, q, PsOr(p, q)), x, y) code = cprint(expr) assert code == "(p ? q : p || q) ? x : y" + + +def test_arrays(): + import sympy as sp + from pystencils import Assignment + from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory + + ctx = KernelCreationContext(default_dtype=SInt(32)) + factory = AstFactory(ctx) + cprint = CAstPrinter() + + arr_1d = factory.parse_sympy(Assignment(sp.Symbol("a1d"), sp.Tuple(1, 2, 3, 4, 5))) + code = cprint(arr_1d) + assert code == "int32_t a1d[5] = { 1, 2, 3, 4, 5 };" + + arr_2d = factory.parse_sympy( + Assignment(sp.Symbol("a2d"), sp.Tuple((1, -1), (2, -2))) + ) + code = cprint(arr_2d) + assert code == "int32_t a2d[2][2] = { { 1, -1 }, { 2, -2 } };" + + arr_3d = factory.parse_sympy( + Assignment(sp.Symbol("a3d"), sp.Tuple(((1, -1), (2, -2)), ((3, -3), (4, -4)))) + ) + code = cprint(arr_3d) + assert ( + code + == "int32_t a3d[2][2][2] = { { { 1, -1 }, { 2, -2 } }, { { 3, -3 }, { 4, -4 } } };" + ) diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py index 16e610a55..914d05594 100644 --- a/tests/nbackend/test_extensions.py +++ b/tests/nbackend/test_extensions.py @@ -18,13 +18,13 @@ def test_literals(): f = Field.create_generic("f", 3) x = sp.Symbol("x") - cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64, const=True), 3))) + cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64), (3,), const=True))) global_constant = PsExpression.make(PsLiteral("C", ctx.default_dtype)) loop_slice = make_slice[ - 0:PsSubscript(cells, factory.parse_index(0)), - 0:PsSubscript(cells, factory.parse_index(1)), - 0:PsSubscript(cells, factory.parse_index(2)), + 0:PsSubscript(cells, (factory.parse_index(0),)), + 0:PsSubscript(cells, (factory.parse_index(1),)), + 0:PsSubscript(cells, (factory.parse_index(2),)), ] ispace = FullIterationSpace.create_from_slice(ctx, loop_slice) diff --git a/tests/nbackend/transformations/test_canonical_clone.py b/tests/nbackend/transformations/test_canonical_clone.py index b158b9178..b5e100ea5 100644 --- a/tests/nbackend/transformations/test_canonical_clone.py +++ b/tests/nbackend/transformations/test_canonical_clone.py @@ -22,8 +22,8 @@ def test_clone_entire_ast(): rho = sp.Symbol("rho") u = sp.symbols("u_:2") - cx = TypedSymbol("cx", Arr(ctx.default_dtype)) - cy = TypedSymbol("cy", Arr(ctx.default_dtype)) + cx = TypedSymbol("cx", Arr(ctx.default_dtype, (5,))) + cy = TypedSymbol("cy", Arr(ctx.default_dtype, (5,))) cxs = sp.IndexedBase(cx, shape=(5,)) cys = sp.IndexedBase(cy, shape=(5,)) diff --git a/tests/nbackend/transformations/test_hoist_invariants.py b/tests/nbackend/transformations/test_hoist_invariants.py index 15514f1da..daa2760c0 100644 --- a/tests/nbackend/transformations/test_hoist_invariants.py +++ b/tests/nbackend/transformations/test_hoist_invariants.py @@ -144,14 +144,14 @@ def test_hoist_arrays(): const_arr_symb = TypedSymbol( "const_arr", - Arr(Fp(64, const=True), 10), + Arr(Fp(64), (10,), const=True), ) const_array_decl = factory.parse_sympy(Assignment(const_arr_symb, tuple(range(10)))) const_arr = sp.IndexedBase(const_arr_symb, shape=(10,)) arr_symb = TypedSymbol( "arr", - Arr(Fp(64, const=False), 10), + Arr(Fp(64), (10,), const=False), ) array_decl = factory.parse_sympy(Assignment(arr_symb, tuple(range(10)))) arr = sp.IndexedBase(arr_symb, shape=(10,)) diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py index 1cc2ae0e4..165d572de 100644 --- a/tests/nbackend/types/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -151,6 +151,48 @@ def test_struct_types(): assert t.itemsize == numpy_type.itemsize == 16 +def test_array_types(): + t = PsArrayType(UInt(64), 42) + assert t.dim == 1 + assert t.shape == (42,) + assert not t.const + assert t.c_string() == "uint64_t[42]" + + assert t == PsArrayType(UInt(64), (42,)) + + t = PsArrayType(UInt(64), [3, 4, 5]) + assert t.dim == 3 + assert t.shape == (3, 4, 5) + assert not t.const + assert t.c_string() == "uint64_t[3][4][5]" + + t = PsArrayType(UInt(64, const=True), [3, 4, 5]) + assert t.dim == 3 + assert t.shape == (3, 4, 5) + assert not t.const + + t = PsArrayType(UInt(64), [3, 4, 5], const=True) + assert t.dim == 3 + assert t.shape == (3, 4, 5) + assert t.const + assert t.c_string() == "const uint64_t[3][4][5]" + + t = PsArrayType(UInt(64, const=True), [3, 4, 5], const=True) + assert t.dim == 3 + assert t.shape == (3, 4, 5) + assert t.const + + with pytest.raises(ValueError): + _ = PsArrayType(UInt(64), (3, 0, 1)) + + with pytest.raises(ValueError): + _ = PsArrayType(UInt(64), (3, 9, -1, 2)) + + # Nested arrays are disallowed + with pytest.raises(ValueError): + _ = PsArrayType(PsArrayType(Bool(), (2,)), (3, 1)) + + def test_pickle(): types = [ Bool(const=True), @@ -165,8 +207,8 @@ def test_pickle(): Fp(width=16, const=True), PsStructType([("x", UInt(32)), ("y", UInt(32)), ("val", Fp(64))], "myStruct"), PsStructType([("data", Fp(32))], "None"), - PsArrayType(Fp(16, const=True), 42), - PsArrayType(PsVectorType(Fp(32), 8, const=False), 42) + PsArrayType(Fp(16), (42,), const=True), + PsArrayType(PsVectorType(Fp(32), 8), (42,)) ] dumped = pickle.dumps(types) -- GitLab