diff --git a/docs/source/api/symbolic_language/astnodes.rst b/docs/source/api/symbolic_language/astnodes.rst index 4d5c4b89f410ba7bcbde695819eb4b7351fcd71b..ff31c98ecbb5822ef3b1fba8b18f577f2c352e0e 100644 --- a/docs/source/api/symbolic_language/astnodes.rst +++ b/docs/source/api/symbolic_language/astnodes.rst @@ -4,6 +4,6 @@ Kernel Structure .. automodule:: pystencils.sympyextensions.astnodes -.. autoclass:: pystencils.sympyextensions.AssignmentCollection +.. autoclass:: pystencils.AssignmentCollection :members: diff --git a/docs/source/backend/ast.rst b/docs/source/backend/ast.rst index 41f23016664002fd100544d72c509f9f73d72bdd..44f8f25409ad08cff1d550dddcd9099bbbc798df 100644 --- a/docs/source/backend/ast.rst +++ b/docs/source/backend/ast.rst @@ -2,29 +2,30 @@ Abstract Syntax Tree ******************** -Inheritance Diagramm -==================== +API Documentation +================= + +Inheritance Diagram +------------------- .. inheritance-diagram:: pystencils.backend.ast.astnode.PsAstNode pystencils.backend.ast.structural pystencils.backend.ast.expressions pystencils.backend.extensions.foreign_ast :top-classes: pystencils.types.PsAstNode :parts: 1 - Base Classes -============ +------------ .. automodule:: pystencils.backend.ast.astnode :members: Structural Nodes -================ +---------------- .. automodule:: pystencils.backend.ast.structural :members: - Expressions -=========== +----------- .. automodule:: pystencils.backend.ast.expressions :members: diff --git a/docs/source/backend/index.rst b/docs/source/backend/index.rst index f2fe9346dbe4d38722b69dd9c279d0eb11c98773..70ed684c69e494a2c9d72f91fdcc62641f4acd69 100644 --- a/docs/source/backend/index.rst +++ b/docs/source/backend/index.rst @@ -15,6 +15,7 @@ who wish to customize or extend the behaviour of the code generator in their app translation platforms transformations + output jit extensions @@ -30,7 +31,7 @@ The IR comprises *symbols*, *constants*, *arrays*, the *iteration space* and the * `PsSymbol` represents a single symbol in the kernel, annotated with a type. Other than in the frontend, uniqueness of symbols is enforced by the backend: of each symbol, at most one instance may exist. * `PsConstant` provides a type-safe representation of constants. -* `PsLinearizedArray` is the backend counterpart to the ubiquitous `Field`, representing a contiguous +* `PsBuffer` is the backend counterpart to the ubiquitous `Field`, representing a contiguous n-dimensional array. These arrays do not occur directly in the IR, but are represented through their *associated symbols*, which are base pointers, shapes, and strides. diff --git a/docs/source/backend/objects.rst b/docs/source/backend/objects.rst index b0c3af6db67ff3cfb1e6a3d3603e84e6c4abb6cb..11cf8ea5e53446a0db9dfac18c983e46eaf3bf36 100644 --- a/docs/source/backend/objects.rst +++ b/docs/source/backend/objects.rst @@ -1,15 +1,123 @@ -***************************** -Symbols, Constants and Arrays -***************************** +**************************** +Constants and Memory Objects +**************************** -.. autoclass:: pystencils.backend.symbols.PsSymbol +Memory Objects: Symbols and Buffers +=================================== + +The Memory Model +---------------- + +In order to reason about memory accesses, mutability, invariance, and aliasing, the *pystencils* backend uses +a very simple memory model. There are three types of memory objects: + +- Symbols (`PsSymbol`), which act as registers for data storage within the scope of a kernel +- Field buffers (`PsBuffer`), which represent a contiguous block of memory the kernel has access to, and +- the *unmanaged heap*, which is a global catch-all memory object which all pointers not belonging to a field + array point into. + +All of these objects are disjoint, and cannot alias each other. +Each symbol exists in isolation, +field buffers do not overlap, +and raw pointers are assumed not to point into memory owned by a symbol or field array. +Instead, all raw pointers point into unmanaged heap memory, and are assumed to *always* alias one another: +Each change brought to unmanaged memory by one raw pointer is assumed to affect the memory pointed to by +another raw pointer. + +Symbols +------- + +In the pystencils IR, instances of `PsSymbol` represent what is generally known as "virtual registers". +These are memory locations that are private to a function, cannot be aliased or pointed to, and will finally reside +either in physical registers or on the stack. +Each symbol has a name and a data type. The data type may initially be `None`, in which case it should soon after be +determined by the `Typifier`. + +Other than their front-end counterpart `sympy.Symbol <sympy.core.symbol.Symbol>`, `PsSymbol` instances are mutable; +their properties can and often will change over time. +As a consequence, they are not comparable by value: +two `PsSymbol` instances with the same name and data type will in general *not* be equal. +In fact, most of the time, it is an error to have two identical symbol instances active. + +Creating Symbols +^^^^^^^^^^^^^^^^ + +During kernel translation, symbols never exist in isolation, but should always be managed by a `KernelCreationContext`. +Symbols can be created and retrieved using `add_symbol <KernelCreationContext.add_symbol>` and `find_symbol <KernelCreationContext.find_symbol>`. +A symbol can also be duplicated using `duplicate_symbol <KernelCreationContext.duplicate_symbol>`, which assigns a new name to the symbol's copy. +The `KernelCreationContext` keeps track of all existing symbols during a kernel translation run +and makes sure that no name and data type conflicts may arise. + +Never call the constructor of `PsSymbol` directly unless you really know what you are doing. + +Symbol Properties +^^^^^^^^^^^^^^^^^ + +Symbols can be annotated with arbitrary information using *symbol properties*. +Each symbol property type must be a subclass of `PsSymbolProperty`. +It is strongly recommended to implement property types using frozen +`dataclasses <https://docs.python.org/3/library/dataclasses.html>`_. +For example, this snippet defines a property type that models pointer alignment requirements: + +.. code-block:: python + + @dataclass(frozen=True) + class AlignmentProperty(UniqueSymbolProperty) + """Require this pointer symbol to be aligned at a particular byte boundary.""" + + byte_boundary: int + +Inheriting from `UniqueSymbolProperty` ensures that at most one property of this type can be attached to +a symbol at any time. +Properties can be added, queried, and removed using the `PsSymbol` properties API listed below. + +Many symbol properties are more relevant to consumers of generated kernels than to the code generator itself. +The above alignment property, for instance, may be added to a pointer symbol by a vectorization pass +to document its assumption that the pointer be properly aligned, in order to emit aligned load and store instructions. +It then becomes the responsibility of the runtime system embedding the kernel to check this prequesite before calling the kernel. +To make sure this information becomes visible, any properties attached to symbols exposed as kernel parameters will also +be added to their respective `KernelParameter` instance. + +Buffers +------- + +Buffers, as represented by the `PsBuffer` class, represent contiguous, n-dimensional, linearized cuboid blocks of memory. +Each buffer has a fixed name and element data type, +and will be represented in the IR via three sets of symbols: + +- The *base pointer* is a symbol of pointer type which points into the buffer's underlying memory area. + Each buffer has at least one, its primary base pointer, whose pointed-to type must be the same as the + buffer's element type. There may be additional base pointers pointing into subsections of that memory. + These additional base pointers may also have deviating data types, as is for instance required for + type erasure in certain cases. + To communicate its role to the code generation system, + each base pointer needs to be marked as such using the `BufferBasePtr` property, + . +- The buffer *shape* defines the size of the buffer in each dimension. Each shape entry is either a `symbol <PsSymbol>` + or a `constant <PsConstant>`. +- The buffer *strides* define the step size to go from one entry to the next in each dimension. + Like the shape, each stride entry is also either a symbol or a constant. + +The shape and stride symbols must all have the same data type, which will be stored as the buffer's index data type. + +Creating and Managing Buffers +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Similarily to symbols, buffers are typically managed by the `KernelCreationContext`, which associates each buffer +to a front-end `Field`. Buffers for fields can be obtained using `get_buffer <KernelCreationContext.get_buffer>`. +The context makes sure to avoid name conflicts between buffers. + +API Documentation +================= + +.. automodule:: pystencils.backend.properties :members: -.. autoclass:: pystencils.backend.constants.PsConstant +.. automodule:: pystencils.backend.memory :members: -.. autoclass:: pystencils.backend.literals.PsLiteral +.. automodule:: pystencils.backend.constants :members: -.. automodule:: pystencils.backend.arrays +.. autoclass:: pystencils.backend.literals.PsLiteral :members: diff --git a/docs/source/backend/output.rst b/docs/source/backend/output.rst new file mode 100644 index 0000000000000000000000000000000000000000..9875e257b0ff76de29370dd40a0ab0772ad1fec1 --- /dev/null +++ b/docs/source/backend/output.rst @@ -0,0 +1,6 @@ +********************* +Code Generator Output +********************* + +.. automodule:: pystencils.backend.kernelfunction + :members: diff --git a/src/pystencils/backend/__init__.py b/src/pystencils/backend/__init__.py index a0b1c8f747984e3fffde5a336f40e2aa46ad631d..b947a112ecb2be7762fefdf54afd4dffc185c319 100644 --- a/src/pystencils/backend/__init__.py +++ b/src/pystencils/backend/__init__.py @@ -1,9 +1,5 @@ from .kernelfunction import ( KernelParameter, - FieldParameter, - FieldShapeParam, - FieldStrideParam, - FieldPointerParam, KernelFunction, GpuKernelFunction, ) @@ -12,10 +8,6 @@ from .constraints import KernelParamsConstraint __all__ = [ "KernelParameter", - "FieldParameter", - "FieldShapeParam", - "FieldStrideParam", - "FieldPointerParam", "KernelFunction", "GpuKernelFunction", "KernelParamsConstraint", diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py deleted file mode 100644 index 9aefeaf62fe65fad045a6d2f20e72ce732f8aa8b..0000000000000000000000000000000000000000 --- a/src/pystencils/backend/arrays.py +++ /dev/null @@ -1,194 +0,0 @@ -from __future__ import annotations - -from typing import Sequence -from types import EllipsisType - -from abc import ABC - -from .constants import PsConstant -from ..types import ( - PsType, - PsPointerType, - PsIntegerType, - PsUnsignedIntegerType, -) - -from .symbols import PsSymbol -from ..defaults import DEFAULTS - - -class PsLinearizedArray: - """Class to model N-dimensional contiguous arrays. - - **Memory Layout, Shape and Strides** - - The memory layout of an array is defined by its shape and strides. - Both shape and stride entries may either be constants or special variables associated with - exactly one array. - - Shape and strides may be specified at construction in the following way. - For constant entries, their value must be given as an integer. - For variable shape entries and strides, the Ellipsis `...` must be passed instead. - Internally, the passed ``index_dtype`` will be used to create typed constants (`PsConstant`) - and variables (`PsArrayShapeSymbol` and `PsArrayStrideSymbol`) from the passed values. - """ - - def __init__( - self, - name: str, - element_type: PsType, - shape: Sequence[int | str | EllipsisType], - strides: Sequence[int | str | EllipsisType], - index_dtype: PsIntegerType = DEFAULTS.index_dtype, - ): - self._name = name - self._element_type = element_type - self._index_dtype = index_dtype - - if len(shape) != len(strides): - raise ValueError("Shape and stride tuples must have the same length") - - def make_shape(coord, name_or_val): - match name_or_val: - case EllipsisType(): - return PsArrayShapeSymbol(DEFAULTS.field_shape_name(name, coord), self, coord) - case str(): - return PsArrayShapeSymbol(name_or_val, self, coord) - case _: - return PsConstant(name_or_val, index_dtype) - - self._shape: tuple[PsArrayShapeSymbol | PsConstant, ...] = tuple( - make_shape(i, s) for i, s in enumerate(shape) - ) - - def make_stride(coord, name_or_val): - match name_or_val: - case EllipsisType(): - return PsArrayStrideSymbol(DEFAULTS.field_stride_name(name, coord), self, coord) - case str(): - return PsArrayStrideSymbol(name_or_val, self, coord) - case _: - return PsConstant(name_or_val, index_dtype) - - self._strides: tuple[PsArrayStrideSymbol | PsConstant, ...] = tuple( - make_stride(i, s) for i, s in enumerate(strides) - ) - - self._base_ptr = PsArrayBasePointer(DEFAULTS.field_pointer_name(name), self) - - @property - def name(self): - """The array's name""" - return self._name - - @property - def base_pointer(self) -> PsArrayBasePointer: - """The array's base pointer""" - return self._base_ptr - - @property - def shape(self) -> tuple[PsArrayShapeSymbol | PsConstant, ...]: - """The array's shape, expressed using `PsConstant` and `PsArrayShapeSymbol`""" - return self._shape - - @property - def strides(self) -> tuple[PsArrayStrideSymbol | PsConstant, ...]: - """The array's strides, expressed using `PsConstant` and `PsArrayStrideSymbol`""" - return self._strides - - @property - def index_type(self) -> PsIntegerType: - return self._index_dtype - - @property - def element_type(self) -> PsType: - return self._element_type - - def __repr__(self) -> str: - return ( - f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])" - ) - - -class PsArrayAssocSymbol(PsSymbol, ABC): - """A variable that is associated to an array. - - Instances of this class represent pointers and indexing information bound - to a particular array. - """ - - __match_args__ = ("name", "dtype", "array") - - def __init__(self, name: str, dtype: PsType, array: PsLinearizedArray): - super().__init__(name, dtype) - self._array = array - - @property - def array(self) -> PsLinearizedArray: - return self._array - - -class PsArrayBasePointer(PsArrayAssocSymbol): - def __init__(self, name: str, array: PsLinearizedArray): - dtype = PsPointerType(array.element_type) - super().__init__(name, dtype, array) - - self._array = array - - -class TypeErasedBasePointer(PsArrayBasePointer): - """Base pointer for arrays whose element type has been erased. - - Used primarily for arrays of anonymous structs.""" - - def __init__(self, name: str, array: PsLinearizedArray): - dtype = PsPointerType(PsUnsignedIntegerType(8)) - super(PsArrayBasePointer, self).__init__(name, dtype, array) - - self._array = array - - -class PsArrayShapeSymbol(PsArrayAssocSymbol): - """Variable that represents an array's shape in one coordinate. - - Do not instantiate this class yourself, but only use its instances - as provided by `PsLinearizedArray.shape`. - """ - - __match_args__ = PsArrayAssocSymbol.__match_args__ + ("coordinate",) - - def __init__( - self, - name: str, - array: PsLinearizedArray, - coordinate: int, - ): - super().__init__(name, array.index_type, array) - self._coordinate = coordinate - - @property - def coordinate(self) -> int: - return self._coordinate - - -class PsArrayStrideSymbol(PsArrayAssocSymbol): - """Variable that represents an array's stride in one coordinate. - - Do not instantiate this class yourself, but only use its instances - as provided by `PsLinearizedArray.strides`. - """ - - __match_args__ = PsArrayAssocSymbol.__match_args__ + ("coordinate",) - - def __init__( - self, - name: str, - array: PsLinearizedArray, - coordinate: int, - ): - super().__init__(name, array.index_type, array) - self._coordinate = coordinate - - @property - def coordinate(self) -> int: - return self._coordinate diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index 5b37470cc89dc8b304d9fd033ea66ec67b9f8d8a..3c6d2ef557e44a882edf4e104df4bd4e2a8830fd 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -16,7 +16,7 @@ from .structural import ( ) from .expressions import ( PsAdd, - PsArrayAccess, + PsBufferAcc, PsCall, PsConstantExpr, PsDiv, @@ -28,9 +28,11 @@ from .expressions import ( PsSub, PsSymbolExpr, PsTernary, + PsSubscript, + PsMemAcc ) -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..exceptions import PsInternalCompilerError from ...types import PsNumericType @@ -282,8 +284,14 @@ class OperationCounter: case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_): return OperationCounts() - case PsArrayAccess(_, index): - return self.visit_expr(index) + case PsBufferAcc(_, indices) | PsSubscript(_, indices): + return reduce( + operator.add, + (self.visit_expr(idx) for idx in indices) + ) + + case PsMemAcc(_, offset): + return self.visit_expr(offset) case PsCall(_, args): return OperationCounts(calls=1) + reduce( diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 5f9c95d5d53d824620b030f7a618b79e8b81564f..d73b1faa758f8ce31312c674712ec89bfd5683ab 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -1,18 +1,19 @@ from __future__ import annotations + from abc import ABC, abstractmethod from typing import Sequence, overload, Callable, Any, cast import operator -from ..symbols import PsSymbol +import numpy as np +from numpy.typing import NDArray + +from ..memory import PsSymbol, PsBuffer, BufferBasePtr from ..constants import PsConstant from ..literals import PsLiteral -from ..arrays import PsLinearizedArray, PsArrayBasePointer from ..functions import PsFunction from ...types import ( PsType, - PsScalarType, PsVectorType, - PsTypeError, ) from .util import failing_cast from ..exceptions import PsInternalCompilerError @@ -33,6 +34,10 @@ class PsExpression(PsAstNode, ABC): The type annotations are used by various transformation passes to make decisions, e.g. in function materialization and intrinsic selection. + + .. attention:: + The ``structurally_equal`` check currently does not take expression data types into + account. This may change in the future. """ def __init__(self, dtype: PsType | None = None) -> None: @@ -93,13 +98,32 @@ class PsExpression(PsAstNode, ABC): else: raise ValueError(f"Cannot make expression out of {obj}") + def clone(self): + """Clone this expression. + + .. note:: + Subclasses of `PsExpression` should not override this method, + but implement `_clone_expr` instead. + That implementation shall call `clone` on any of its subexpressions, + but does not need to fix the `dtype` property. + The `dtype` is correctly applied by `PsExpression.clone` internally. + """ + cloned = self._clone_expr() + cloned._dtype = self.dtype + return cloned + @abstractmethod - def clone(self) -> PsExpression: + def _clone_expr(self) -> PsExpression: + """Implementation of expression cloning. + + :meta public: + """ pass class PsLvalue(ABC): - """Mix-in for all expressions that may occur as an lvalue""" + """Mix-in for all expressions that may occur as an lvalue; + i.e. expressions that represent a memory location.""" class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression): @@ -119,7 +143,7 @@ class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression): def symbol(self, symbol: PsSymbol): self._symbol = symbol - def clone(self) -> PsSymbolExpr: + def _clone_expr(self) -> PsSymbolExpr: return PsSymbolExpr(self._symbol) def structurally_equal(self, other: PsAstNode) -> bool: @@ -147,7 +171,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): def constant(self, c: PsConstant): self._constant = c - def clone(self) -> PsConstantExpr: + def _clone_expr(self) -> PsConstantExpr: return PsConstantExpr(self._constant) def structurally_equal(self, other: PsAstNode) -> bool: @@ -175,7 +199,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): def literal(self, lit: PsLiteral): self._literal = lit - def clone(self) -> PsLiteralExpr: + def _clone_expr(self) -> PsLiteralExpr: return PsLiteralExpr(self._literal) def structurally_equal(self, other: PsAstNode) -> bool: @@ -188,117 +212,178 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): return f"PsLiteralExpr({repr(self._literal)})" +class PsBufferAcc(PsLvalue, PsExpression): + """Access into a `PsBuffer`.""" + + __match_args__ = ("base_pointer", "index") + + def __init__(self, base_ptr: PsSymbol, index: Sequence[PsExpression]): + super().__init__() + bptr_prop = cast(BufferBasePtr, base_ptr.get_properties(BufferBasePtr).pop()) + + if len(index) != bptr_prop.buffer.dim: + raise ValueError("Number of index expressions must equal buffer shape.") + + self._base_ptr = PsExpression.make(base_ptr) + self._index = list(index) + self._dtype = bptr_prop.buffer.element_type + + @property + def base_pointer(self) -> PsSymbolExpr: + return self._base_ptr + + @base_pointer.setter + def base_pointer(self, expr: PsSymbolExpr): + bptr_prop = cast(BufferBasePtr, expr.symbol.get_properties(BufferBasePtr).pop()) + if bptr_prop.buffer != self.buffer: + raise ValueError( + "Cannot replace a buffer access's base pointer with one belonging to a different buffer." + ) + + self._base_ptr = expr + + @property + def buffer(self) -> PsBuffer: + return cast( + BufferBasePtr, self._base_ptr.symbol.get_properties(BufferBasePtr).pop() + ).buffer + + @property + def index(self) -> list[PsExpression]: + return self._index + + def get_children(self) -> tuple[PsAstNode, ...]: + return (self._base_ptr,) + tuple(self._index) + + def set_child(self, idx: int, c: PsAstNode): + idx = range(len(self._index) + 1)[idx] + if idx == 0: + self.base_pointer = failing_cast(PsSymbolExpr, c) + else: + self._index[idx - 1] = failing_cast(PsExpression, c) + + def _clone_expr(self) -> PsBufferAcc: + return PsBufferAcc(self._base_ptr.symbol, [i.clone() for i in self._index]) + + def __repr__(self) -> str: + return f"PsBufferAcc({repr(self._base_ptr)}, {repr(self._index)})" + + class PsSubscript(PsLvalue, PsExpression): - __match_args__ = ("base", "index") + """N-dimensional subscript into an array.""" + + __match_args__ = ("array", "index") - def __init__(self, base: PsExpression, index: PsExpression): + def __init__(self, arr: PsExpression, index: Sequence[PsExpression]): super().__init__() - self._base = base - self._index = index + self._arr = arr + + if not index: + raise ValueError("Subscript index cannot be empty.") + + self._index = list(index) @property - def base(self) -> PsExpression: - return self._base + def array(self) -> PsExpression: + return self._arr - @base.setter - def base(self, expr: PsExpression): - self._base = expr + @array.setter + def array(self, expr: PsExpression): + self._arr = expr @property - def index(self) -> PsExpression: + def index(self) -> list[PsExpression]: return self._index @index.setter - def index(self, expr: PsExpression): - self._index = expr + def index(self, idx: Sequence[PsExpression]): + self._index = list(idx) - def clone(self) -> PsSubscript: - return PsSubscript(self._base.clone(), self._index.clone()) + def _clone_expr(self) -> PsSubscript: + return PsSubscript(self._arr.clone(), [i.clone() for i in self._index]) def get_children(self) -> tuple[PsAstNode, ...]: - return (self._base, self._index) + return (self._arr,) + tuple(self._index) def set_child(self, idx: int, c: PsAstNode): - idx = [0, 1][idx] + idx = range(len(self._index) + 1)[idx] match idx: case 0: - self.base = failing_cast(PsExpression, c) - case 1: - self.index = failing_cast(PsExpression, c) + self.array = failing_cast(PsExpression, c) + case _: + self.index[idx - 1] = failing_cast(PsExpression, c) def __repr__(self) -> str: - return f"Subscript({self._base})[{self._index}]" + idx = ", ".join(repr(i) for i in self._index) + return f"PsSubscript({repr(self._arr)}, {repr(idx)})" -class PsArrayAccess(PsSubscript): - __match_args__ = ("base_ptr", "index") +class PsMemAcc(PsLvalue, PsExpression): + """Pointer-based memory access with type-dependent offset.""" - def __init__(self, base_ptr: PsArrayBasePointer, index: PsExpression): - super().__init__(PsExpression.make(base_ptr), index) - self._base_ptr = base_ptr - self._dtype = base_ptr.array.element_type + __match_args__ = ("pointer", "offset") + + def __init__(self, ptr: PsExpression, offset: PsExpression): + super().__init__() + self._ptr = ptr + self._offset = offset @property - def base_ptr(self) -> PsArrayBasePointer: - return self._base_ptr + def pointer(self) -> PsExpression: + return self._ptr + + @pointer.setter + def pointer(self, expr: PsExpression): + self._ptr = expr @property - def base(self) -> PsExpression: - return self._base + def offset(self) -> PsExpression: + return self._offset - @base.setter - def base(self, expr: PsExpression): - if not isinstance(expr, PsSymbolExpr) or not isinstance( - expr.symbol, PsArrayBasePointer - ): - raise ValueError( - "Base expression of PsArrayAccess must be an array base pointer" - ) + @offset.setter + def offset(self, expr: PsExpression): + self._offset = expr - self._base_ptr = expr.symbol - self._base = expr + def _clone_expr(self) -> PsMemAcc: + return PsMemAcc(self._ptr.clone(), self._offset.clone()) - @property - def array(self) -> PsLinearizedArray: - return self._base_ptr.array + def get_children(self) -> tuple[PsAstNode, ...]: + return (self._ptr, self._offset) - def clone(self) -> PsArrayAccess: - return PsArrayAccess(self._base_ptr, self._index.clone()) + def set_child(self, idx: int, c: PsAstNode): + idx = [0, 1][idx] + match idx: + case 0: + self.pointer = failing_cast(PsExpression, c) + case 1: + self.offset = failing_cast(PsExpression, c) def __repr__(self) -> str: - return f"ArrayAccess({repr(self._base_ptr)}, {repr(self._index)})" + return f"PsMemAcc({repr(self._ptr)}, {repr(self._offset)})" + +class PsVectorMemAcc(PsMemAcc): + """Pointer-based vectorized memory access.""" -class PsVectorArrayAccess(PsArrayAccess): __match_args__ = ("base_ptr", "base_index") def __init__( self, - base_ptr: PsArrayBasePointer, + base_ptr: PsExpression, base_index: PsExpression, vector_entries: int, stride: int = 1, alignment: int = 0, ): super().__init__(base_ptr, base_index) - element_type = base_ptr.array.element_type - if not isinstance(element_type, PsScalarType): - raise PsTypeError( - "Cannot generate vector accesses to arrays with non-scalar elements" - ) - - self._vector_type = PsVectorType( - element_type, vector_entries, const=element_type.const - ) + self._vector_entries = vector_entries self._stride = stride self._alignment = alignment - self._dtype = self._vector_type - @property def vector_entries(self) -> int: - return self._vector_type.vector_entries + return self._vector_entries @property def stride(self) -> int: @@ -311,22 +396,22 @@ class PsVectorArrayAccess(PsArrayAccess): def get_vector_type(self) -> PsVectorType: return cast(PsVectorType, self._dtype) - def clone(self) -> PsVectorArrayAccess: - return PsVectorArrayAccess( - self._base_ptr, - self._index.clone(), + def _clone_expr(self) -> PsVectorMemAcc: + return PsVectorMemAcc( + self._ptr.clone(), + self._offset.clone(), self.vector_entries, self._stride, self._alignment, ) def structurally_equal(self, other: PsAstNode) -> bool: - if not isinstance(other, PsVectorArrayAccess): + if not isinstance(other, PsVectorMemAcc): return False return ( super().structurally_equal(other) - and self._vector_type == other._vector_type + and self._vector_entries == other._vector_entries and self._stride == other._stride and self._alignment == other._alignment ) @@ -356,7 +441,7 @@ class PsLookup(PsExpression, PsLvalue): def member_name(self, name: str): self._name = name - def clone(self) -> PsLookup: + def _clone_expr(self) -> PsLookup: return PsLookup(self._aggregate.clone(), self._member_name) def get_children(self) -> tuple[PsAstNode, ...]: @@ -366,6 +451,9 @@ class PsLookup(PsExpression, PsLvalue): idx = [0][idx] self._aggregate = failing_cast(PsExpression, c) + def __repr__(self) -> str: + return f"PsLookup({repr(self._aggregate)}, {repr(self._member_name)})" + class PsCall(PsExpression): __match_args__ = ("function", "args") @@ -406,7 +494,7 @@ class PsCall(PsExpression): self._args = list(exprs) - def clone(self) -> PsCall: + def _clone_expr(self) -> PsCall: return PsCall(self._function, [arg.clone() for arg in self._args]) def get_children(self) -> tuple[PsAstNode, ...]: @@ -450,7 +538,7 @@ class PsTernary(PsExpression): def case_else(self) -> PsExpression: return self._else - def clone(self) -> PsExpression: + def _clone_expr(self) -> PsExpression: return PsTernary(self._cond.clone(), self._then.clone(), self._else.clone()) def get_children(self) -> tuple[PsExpression, ...]: @@ -500,7 +588,7 @@ class PsUnOp(PsExpression): def operand(self, expr: PsExpression): self._operand = expr - def clone(self) -> PsUnOp: + def _clone_expr(self) -> PsUnOp: return type(self)(self._operand.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -525,11 +613,17 @@ class PsNeg(PsUnOp, PsNumericOpTrait): return operator.neg -class PsDeref(PsLvalue, PsUnOp): - pass - - class PsAddressOf(PsUnOp): + """Take the address of a memory location. + + .. DANGER:: + Taking the address of a memory location owned by a symbol or field array + introduces an alias to that memory location. + As pystencils assumes its symbols and fields to never be aliased, this can + subtly change the semantics of a kernel. + Use the address-of operator with utmost care. + """ + pass @@ -548,7 +642,7 @@ class PsCast(PsUnOp): def target_type(self, dtype: PsType): self._target_type = dtype - def clone(self) -> PsUnOp: + def _clone_expr(self) -> PsUnOp: return PsCast(self._target_type, self._operand.clone()) def structurally_equal(self, other: PsAstNode) -> bool: @@ -584,7 +678,7 @@ class PsBinOp(PsExpression): def operand2(self, expr: PsExpression): self._op2 = expr - def clone(self) -> PsBinOp: + def _clone_expr(self) -> PsBinOp: return type(self)(self._op1.clone(), self._op2.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -740,32 +834,47 @@ class PsLt(PsRel): class PsArrayInitList(PsExpression): + """N-dimensional array initialization matrix.""" + __match_args__ = ("items",) - def __init__(self, items: Sequence[PsExpression]): + def __init__( + self, + items: Sequence[PsExpression | Sequence[PsExpression | Sequence[PsExpression]]], + ): super().__init__() - self._items = list(items) + self._items = np.array(items, dtype=np.object_) @property - def items(self) -> list[PsExpression]: + def items_grid(self) -> NDArray[np.object_]: return self._items + @property + def shape(self) -> tuple[int, ...]: + return self._items.shape + + @property + def items(self) -> tuple[PsExpression, ...]: + return tuple(self._items.flat) # type: ignore + def get_children(self) -> tuple[PsAstNode, ...]: - return tuple(self._items) + return tuple(self._items.flat) # type: ignore def set_child(self, idx: int, c: PsAstNode): - self._items[idx] = failing_cast(PsExpression, c) + self._items.flat[idx] = failing_cast(PsExpression, c) - def clone(self) -> PsExpression: - return PsArrayInitList([expr.clone() for expr in self._items]) + def _clone_expr(self) -> PsExpression: + return PsArrayInitList( + np.array([expr.clone() for expr in self.children]).reshape( # type: ignore + self._items.shape + ) + ) def __repr__(self) -> str: return f"PsArrayInitList({repr(self._items)})" -def evaluate_expression( - expr: PsExpression, valuation: dict[str, Any] -) -> Any: +def evaluate_expression(expr: PsExpression, valuation: dict[str, Any]) -> Any: """Evaluate a pystencils backend expression tree with values assigned to symbols according to the given valuation. Only a subset of expression nodes can be processed by this evaluator. diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index cd3aae30d35061ab6c15c338a735aaecca83a141..3ae462c41c0170dcaa4a27adbd6d039df8c099d8 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -4,7 +4,7 @@ from types import NoneType from .astnode import PsAstNode, PsLeafMixIn from .expressions import PsExpression, PsLvalue, PsSymbolExpr -from ..symbols import PsSymbol +from ..memory import PsSymbol from .util import failing_cast @@ -320,7 +320,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``. Args: - text: The pragma's text, without the ``#pragma ``. + text: The pragma's text, without the ``#pragma``. """ __match_args__ = ("text",) diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py index 0d3b78629fa9ee41d753893b1b6b4198cc75ae51..288097a901e3f11f4a6f12c47799b25ec672151e 100644 --- a/src/pystencils/backend/ast/util.py +++ b/src/pystencils/backend/ast/util.py @@ -1,8 +1,15 @@ from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, cast + +from ..exceptions import PsInternalCompilerError +from ..memory import PsSymbol +from ..memory import PsBuffer +from ...types import PsDereferencableType + if TYPE_CHECKING: from .astnode import PsAstNode + from .expressions import PsExpression def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any: @@ -36,3 +43,43 @@ class AstEqWrapper: # TODO: consider replacing this with smth. more performant # TODO: Check that repr is implemented by all AST nodes return hash(repr(self._node)) + + +def determine_memory_object( + expr: PsExpression, +) -> tuple[PsSymbol | PsBuffer | None, bool]: + """Return the memory object accessed by the given expression, together with its constness + + Returns: + Tuple ``(mem_obj, const)`` identifying the memory object accessed by the given expression, + as well as its constness + """ + from pystencils.backend.ast.expressions import ( + PsSubscript, + PsLookup, + PsSymbolExpr, + PsMemAcc, + PsBufferAcc, + ) + + while isinstance(expr, (PsSubscript, PsLookup)): + match expr: + case PsSubscript(arr, _): + expr = arr + case PsLookup(record, _): + expr = record + + match expr: + case PsSymbolExpr(symb): + return symb, symb.get_dtype().const + case PsMemAcc(ptr, _): + return None, cast(PsDereferencableType, ptr.get_dtype()).base_type.const + case PsBufferAcc(ptr, _): + return ( + expr.buffer, + cast(PsDereferencableType, ptr.get_dtype()).base_type.const, + ) + case _: + raise PsInternalCompilerError( + "The given expression is a transient and does not refer to a memory object" + ) diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 8928cc6894284a417341f0b12a9e8d6a07f8ba48..6196d69bee44be11d48d2e3e18e731a730fb47d4 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -16,6 +16,7 @@ from .ast.structural import ( ) from .ast.expressions import ( + PsExpression, PsAdd, PsAddressOf, PsArrayInitList, @@ -26,7 +27,7 @@ from .ast.expressions import ( PsCall, PsCast, PsConstantExpr, - PsDeref, + PsMemAcc, PsDiv, PsRem, PsIntDiv, @@ -36,10 +37,9 @@ from .ast.expressions import ( PsNeg, PsRightShift, PsSub, - PsSubscript, PsSymbolExpr, PsLiteralExpr, - PsVectorArrayAccess, + PsVectorMemAcc, PsTernary, PsAnd, PsOr, @@ -50,11 +50,14 @@ from .ast.expressions import ( PsLt, PsGe, PsLe, + PsSubscript, + PsBufferAcc, ) from .extensions.foreign_ast import PsForeignExpression -from .symbols import PsSymbol +from .exceptions import PsInternalCompilerError +from .memory import PsSymbol from ..types import PsScalarType, PsArrayType from .kernelfunction import KernelFunction, GpuKernelFunction @@ -267,19 +270,30 @@ class CAstPrinter: case PsLiteralExpr(lit): return lit.text - case PsVectorArrayAccess(): + case PsVectorMemAcc(): raise EmissionError("Cannot print vectorized array accesses") - case PsSubscript(base, index): + case PsMemAcc(base, offset): pc.push_op(Ops.Subscript, LR.Left) base_code = self.visit(base, pc) pc.pop_op() pc.push_op(Ops.Weakest, LR.Middle) - index_code = self.visit(index, pc) + index_code = self.visit(offset, pc) pc.pop_op() return pc.parenthesize(f"{base_code}[{index_code}]", Ops.Subscript) + + case PsSubscript(base, indices): + pc.push_op(Ops.Subscript, LR.Left) + base_code = self.visit(base, pc) + pc.pop_op() + + pc.push_op(Ops.Weakest, LR.Middle) + indices_code = "".join("[" + self.visit(idx, pc) + "]" for idx in indices) + pc.pop_op() + + return pc.parenthesize(base_code + indices_code, Ops.Subscript) case PsLookup(aggr, member_name): pc.push_op(Ops.Lookup, LR.Left) @@ -320,12 +334,12 @@ class CAstPrinter: return pc.parenthesize(f"!{operand_code}", Ops.Not) - case PsDeref(operand): - pc.push_op(Ops.Deref, LR.Right) - operand_code = self.visit(operand, pc) - pc.pop_op() + # case PsDeref(operand): + # pc.push_op(Ops.Deref, LR.Right) + # operand_code = self.visit(operand, pc) + # pc.pop_op() - return pc.parenthesize(f"*{operand_code}", Ops.Deref) + # return pc.parenthesize(f"*{operand_code}", Ops.Deref) case PsAddressOf(operand): pc.push_op(Ops.AddressOf, LR.Right) @@ -355,17 +369,31 @@ class CAstPrinter: f"{cond_code} ? {then_code} : {else_code}", Ops.Ternary ) - case PsArrayInitList(items): + case PsArrayInitList(_): + def print_arr(item) -> str: + if isinstance(item, PsExpression): + return self.visit(item, pc) + else: + # it's a subarray + entries = ", ".join(print_arr(i) for i in item) + return "{ " + entries + " }" + pc.push_op(Ops.Weakest, LR.Middle) - items_str = ", ".join(self.visit(item, pc) for item in items) + arr_str = print_arr(node.items_grid) pc.pop_op() - return "{ " + items_str + " }" + return arr_str case PsForeignExpression(children): pc.push_op(Ops.Weakest, LR.Middle) foreign_code = node.get_code(self.visit(c, pc) for c in children) pc.pop_op() return foreign_code + + case PsBufferAcc(): + raise PsInternalCompilerError( + f"Unable to print C code for buffer access {node}.\n" + f"Buffer accesses must be lowered using the `LowerToC` pass before emission." + ) case _: raise NotImplementedError(f"Don't know how to print {node}") @@ -379,10 +407,11 @@ class CAstPrinter: def _symbol_decl(self, symb: PsSymbol): dtype = symb.get_dtype() - array_dims = [] - while isinstance(dtype, PsArrayType): - array_dims.append(dtype.length) + if isinstance(dtype, PsArrayType): + array_dims = dtype.shape dtype = dtype.base_type + else: + array_dims = () code = f"{dtype.c_string()} {symb.name}" for d in array_dims: diff --git a/src/pystencils/backend/extensions/cpp.py b/src/pystencils/backend/extensions/cpp.py index 1055b79e9ab197d62c4307b70ac5b2a71c13f139..025f4a3fb61d51d7fd9c485b597a671ae2cfc231 100644 --- a/src/pystencils/backend/extensions/cpp.py +++ b/src/pystencils/backend/extensions/cpp.py @@ -25,7 +25,7 @@ class CppMethodCall(PsForeignExpression): return super().structurally_equal(other) and self._method == other._method - def clone(self) -> CppMethodCall: + def _clone_expr(self) -> CppMethodCall: return CppMethodCall( cast(PsExpression, self.children[0]), self._method, diff --git a/src/pystencils/backend/jit/cpu_extension_module.py b/src/pystencils/backend/jit/cpu_extension_module.py index b9b79358908686ce7ce5ab412d89b64948cdcd3f..d7f64455082523a0d189e9864080cb62827ea156 100644 --- a/src/pystencils/backend/jit/cpu_extension_module.py +++ b/src/pystencils/backend/jit/cpu_extension_module.py @@ -13,11 +13,8 @@ from ..exceptions import PsInternalCompilerError from ..kernelfunction import ( KernelFunction, KernelParameter, - FieldParameter, - FieldShapeParam, - FieldStrideParam, - FieldPointerParam, ) +from ..properties import FieldBasePtr, FieldShape, FieldStride from ..constraints import KernelParamsConstraint from ...types import ( PsType, @@ -209,7 +206,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ self._array_extractions: dict[Field, str] = dict() self._array_frees: dict[Field, str] = dict() - self._array_assoc_var_extractions: dict[FieldParameter, str] = dict() + self._array_assoc_var_extractions: dict[KernelParameter, str] = dict() self._scalar_extractions: dict[KernelParameter, str] = dict() self._constraint_checks: list[str] = [] @@ -282,31 +279,34 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return param.name - def extract_array_assoc_var(self, param: FieldParameter) -> str: + def extract_array_assoc_var(self, param: KernelParameter) -> str: if param not in self._array_assoc_var_extractions: - field = param.field + field = param.fields[0] buffer = self.extract_field(field) - match param: - case FieldPointerParam(): - code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;" - case FieldShapeParam(): - coord = param.coordinate - code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];" - case FieldStrideParam(): - coord = param.coordinate - code = ( - f"{param.dtype} {param.name} = " - f"{buffer}.strides[{coord}] / {field.dtype.itemsize};" - ) - case _: - assert False, "unreachable code" + code: str | None = None + + for prop in param.properties: + match prop: + case FieldBasePtr(): + code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;" + break + case FieldShape(_, coord): + code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];" + break + case FieldStride(_, coord): + code = ( + f"{param.dtype} {param.name} = " + f"{buffer}.strides[{coord}] / {field.dtype.itemsize};" + ) + break + assert code is not None self._array_assoc_var_extractions[param] = code return param.name def extract_parameter(self, param: KernelParameter): - if isinstance(param, FieldParameter): + if param.is_field_parameter: self.extract_array_assoc_var(param) else: self.extract_scalar(param) diff --git a/src/pystencils/backend/jit/gpu_cupy.py b/src/pystencils/backend/jit/gpu_cupy.py index d6aaac2d2ee26901b73ae9d2b6daaba06c055bd0..7f38d9d434333c0d504babe5ed1fea65f6f85dad 100644 --- a/src/pystencils/backend/jit/gpu_cupy.py +++ b/src/pystencils/backend/jit/gpu_cupy.py @@ -16,11 +16,9 @@ from .jit import JitBase, JitError, KernelWrapper from ..kernelfunction import ( KernelFunction, GpuKernelFunction, - FieldPointerParam, - FieldShapeParam, - FieldStrideParam, KernelParameter, ) +from ..properties import FieldShape, FieldStride, FieldBasePtr from ..emission import emit_code from ...types import PsStructType @@ -98,8 +96,8 @@ class CupyKernelWrapper(KernelWrapper): field_shapes = set() index_shapes = set() - def check_shape(field_ptr: FieldPointerParam, arr: cp.ndarray): - field = field_ptr.field + def check_shape(field_ptr: KernelParameter, arr: cp.ndarray): + field = field_ptr.fields[0] if field.has_fixed_shape: expected_shape = tuple(int(s) for s in field.shape) @@ -118,7 +116,7 @@ class CupyKernelWrapper(KernelWrapper): if isinstance(field.dtype, PsStructType): assert expected_strides[-1] == 1 expected_strides = expected_strides[:-1] - + actual_strides = tuple(s // arr.dtype.itemsize for s in arr.strides) if expected_strides != actual_strides: raise ValueError( @@ -149,28 +147,38 @@ class CupyKernelWrapper(KernelWrapper): arr: cp.ndarray for kparam in self._kfunc.parameters: - match kparam: - case FieldPointerParam(_, dtype, field): - arr = kwargs[field.name] - if arr.dtype != field.dtype.numpy_dtype: - raise JitError( - f"Data type mismatch at array argument {field.name}:" - f"Expected {field.dtype}, got {arr.dtype}" - ) - check_shape(kparam, arr) - args.append(arr) - - case FieldShapeParam(name, dtype, field, coord): - arr = kwargs[field.name] - add_arg(name, arr.shape[coord], dtype) - - case FieldStrideParam(name, dtype, field, coord): - arr = kwargs[field.name] - add_arg(name, arr.strides[coord] // arr.dtype.itemsize, dtype) - - case KernelParameter(name, dtype): - val: Any = kwargs[name] - add_arg(name, val, dtype) + if kparam.is_field_parameter: + # Determine field-associated data to pass in + for prop in kparam.properties: + match prop: + case FieldBasePtr(field): + arr = kwargs[field.name] + if arr.dtype != field.dtype.numpy_dtype: + raise JitError( + f"Data type mismatch at array argument {field.name}:" + f"Expected {field.dtype}, got {arr.dtype}" + ) + check_shape(kparam, arr) + args.append(arr) + break + + case FieldShape(field, coord): + arr = kwargs[field.name] + add_arg(kparam.name, arr.shape[coord], kparam.dtype) + break + + case FieldStride(field, coord): + arr = kwargs[field.name] + add_arg( + kparam.name, + arr.strides[coord] // arr.dtype.itemsize, + kparam.dtype, + ) + break + else: + # scalar parameter + val: Any = kwargs[kparam.name] + add_arg(kparam.name, val, kparam.dtype) # Determine launch grid from ..ast.expressions import evaluate_expression diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index a0328a123893a9f103c0ef66aa98028cc5437708..2462e5e66ea1a55cd638df07f645b213dd37d68f 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -8,11 +8,11 @@ from ..ast import PsAstNode from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr from ..ast.structural import PsLoop, PsBlock, PsAssignment -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..constants import PsConstant from .context import KernelCreationContext -from .freeze import FreezeExpressions +from .freeze import FreezeExpressions, ExprLike from .typification import Typifier from .iteration_space import FullIterationSpace @@ -42,14 +42,14 @@ class AstFactory: pass @overload - def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression: + def parse_sympy(self, sp_obj: ExprLike) -> PsExpression: pass @overload def parse_sympy(self, sp_obj: AssignmentBase) -> PsAssignment: pass - def parse_sympy(self, sp_obj: sp.Expr | AssignmentBase) -> PsAstNode: + def parse_sympy(self, sp_obj: ExprLike | AssignmentBase) -> PsAstNode: """Parse a SymPy expression or assignment through `FreezeExpressions` and `Typifier`. The expression or assignment will be typified in a numerical context, using the kernel @@ -170,6 +170,8 @@ class AstFactory: raise ValueError( "Cannot parse a slice with `stop == None` if no normalization limit is given" ) + + assert stop is not None # for mypy return start, stop, step diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 73e3c70cc315c85c9ec01d2db401d10eaa70c53f..839b8fd9829a83b46dbe2419013959b42943b96c 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -9,14 +9,14 @@ from ...defaults import DEFAULTS from ...field import Field, FieldType from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType -from ..symbols import PsSymbol -from ..arrays import PsLinearizedArray +from ..memory import PsSymbol, PsBuffer +from ..properties import FieldShape, FieldStride +from ..constants import PsConstant from ...types import ( PsType, PsIntegerType, PsNumericType, - PsScalarType, - PsStructType, + PsPointerType, deconstify, ) from ..constraints import KernelParamsConstraint @@ -221,64 +221,14 @@ class KernelCreationContext: else: return - arr_shape: list[str | int] | None = None - arr_strides: list[str | int] | None = None - - def normalize_type(s: TypedSymbol) -> PsIntegerType: - match s.dtype: - case DynamicType.INDEX_TYPE: - return self.index_dtype - case DynamicType.NUMERIC_TYPE: - if isinstance(self.default_dtype, PsIntegerType): - return self.default_dtype - else: - raise KernelConstraintsError( - f"Cannot use non-integer default numeric type {self.default_dtype} " - f"in field indexing symbol {s}." - ) - case PsIntegerType(): - return deconstify(s.dtype) - case _: - raise KernelConstraintsError( - f"Invalid data type for field indexing symbol {s}: {s.dtype}" - ) - - # Check field constraints and add to collection + # Check field constraints, create buffer, and add them to the collection match field.field_type: case FieldType.GENERIC | FieldType.STAGGERED | FieldType.STAGGERED_FLUX: + buf = self._create_regular_field_buffer(field) self._fields_collection.domain_fields.add(field) case FieldType.BUFFER: - if field.spatial_dimensions != 1: - raise KernelConstraintsError( - f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. " - "Buffer fields must be one-dimensional." - ) - - if field.index_dimensions > 1: - raise KernelConstraintsError( - f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. " - "Buffer fields can have at most one index dimension." - ) - - num_entries = field.index_shape[0] if field.index_shape else 1 - if not isinstance(num_entries, int): - raise KernelConstraintsError( - f"Invalid index shape of buffer field {field.name}: {num_entries}. " - "Buffer fields cannot have variable index shape." - ) - - buffer_len = field.spatial_shape[0] - - if isinstance(buffer_len, TypedSymbol): - idx_type = normalize_type(buffer_len) - arr_shape = [buffer_len.name, num_entries] - else: - idx_type = DEFAULTS.index_dtype - arr_shape = [buffer_len, num_entries] - - arr_strides = [num_entries, 1] - + buf = self._create_buffer_field_buffer(field) self._fields_collection.buffer_fields.add(field) case FieldType.INDEXED: @@ -287,67 +237,24 @@ class KernelCreationContext: f"Invalid spatial shape of index field {field.name}: {field.spatial_dimensions}. " "Index fields must be one-dimensional." ) + buf = self._create_regular_field_buffer(field) self._fields_collection.index_fields.add(field) case FieldType.CUSTOM: + buf = self._create_regular_field_buffer(field) self._fields_collection.custom_fields.add(field) case _: assert False, "unreachable code" - # For non-buffer fields, determine shape and strides - - if arr_shape is None: - idx_types = set( - normalize_type(s) - for s in chain(field.shape, field.strides) - if isinstance(s, TypedSymbol) - ) - - if len(idx_types) > 1: - raise KernelConstraintsError( - f"Multiple incompatible types found in index symbols of field {field}: " - f"{idx_types}" - ) - idx_type = idx_types.pop() if len(idx_types) > 0 else self.index_dtype - - arr_shape = [ - (s.name if isinstance(s, TypedSymbol) else s) for s in field.shape - ] - - arr_strides = [ - (s.name if isinstance(s, TypedSymbol) else s) for s in field.strides - ] - - # The frontend doesn't quite agree with itself on how to model - # fields with trivial index dimensions. Sometimes the index_shape is empty, - # sometimes its (1,). This is canonicalized here. - if not field.index_shape: - arr_shape += [1] - arr_strides += [1] - - # Add array - assert arr_strides is not None - assert idx_type is not None - - assert isinstance(field.dtype, (PsScalarType, PsStructType)) - element_type = field.dtype - - arr = PsLinearizedArray( - field.name, element_type, arr_shape, arr_strides, idx_type - ) - - self._fields_and_arrays[field.name] = FieldArrayPair(field, arr) - for symb in chain([arr.base_pointer], arr.shape, arr.strides): - if isinstance(symb, PsSymbol): - self.add_symbol(symb) + self._fields_and_arrays[field.name] = FieldArrayPair(field, buf) @property - def arrays(self) -> Iterable[PsLinearizedArray]: + def arrays(self) -> Iterable[PsBuffer]: # return self._fields_and_arrays.values() yield from (item.array for item in self._fields_and_arrays.values()) - def get_array(self, field: Field) -> PsLinearizedArray: + def get_buffer(self, field: Field) -> PsBuffer: """Retrieve the underlying array for a given field. If the given field was not previously registered using `add_field`, @@ -393,3 +300,114 @@ class KernelCreationContext: def require_header(self, header: str): self._req_headers.add(header) + + # ----------- Internals --------------------------------------------------------------------- + + def _normalize_type(self, s: TypedSymbol) -> PsIntegerType: + match s.dtype: + case DynamicType.INDEX_TYPE: + return self.index_dtype + case DynamicType.NUMERIC_TYPE: + if isinstance(self.default_dtype, PsIntegerType): + return self.default_dtype + else: + raise KernelConstraintsError( + f"Cannot use non-integer default numeric type {self.default_dtype} " + f"in field indexing symbol {s}." + ) + case PsIntegerType(): + return deconstify(s.dtype) + case _: + raise KernelConstraintsError( + f"Invalid data type for field indexing symbol {s}: {s.dtype}" + ) + + def _create_regular_field_buffer(self, field: Field) -> PsBuffer: + idx_types = set( + self._normalize_type(s) + for s in chain(field.shape, field.strides) + if isinstance(s, TypedSymbol) + ) + + if len(idx_types) > 1: + raise KernelConstraintsError( + f"Multiple incompatible types found in index symbols of field {field}: " + f"{idx_types}" + ) + + idx_type = idx_types.pop() if len(idx_types) > 0 else self.index_dtype + + def convert_size(s: TypedSymbol | int) -> PsSymbol | PsConstant: + if isinstance(s, TypedSymbol): + return self.get_symbol(s.name, idx_type) + else: + return PsConstant(s, idx_type) + + buf_shape = [convert_size(s) for s in field.shape] + buf_strides = [convert_size(s) for s in field.strides] + + # The frontend doesn't quite agree with itself on how to model + # fields with trivial index dimensions. Sometimes the index_shape is empty, + # sometimes its (1,). This is canonicalized here. + if not field.index_shape: + buf_shape += [convert_size(1)] + buf_strides += [convert_size(1)] + + for i, size in enumerate(buf_shape): + if isinstance(size, PsSymbol): + size.add_property(FieldShape(field, i)) + + for i, stride in enumerate(buf_strides): + if isinstance(stride, PsSymbol): + stride.add_property(FieldStride(field, i)) + + base_ptr = self.get_symbol( + DEFAULTS.field_pointer_name(field.name), + PsPointerType(field.dtype, restrict=True), + ) + + return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides) + + def _create_buffer_field_buffer(self, field: Field) -> PsBuffer: + if field.spatial_dimensions != 1: + raise KernelConstraintsError( + f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. " + "Buffer fields must be one-dimensional." + ) + + if field.index_dimensions > 1: + raise KernelConstraintsError( + f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. " + "Buffer fields can have at most one index dimension." + ) + + num_entries = field.index_shape[0] if field.index_shape else 1 + if not isinstance(num_entries, int): + raise KernelConstraintsError( + f"Invalid index shape of buffer field {field.name}: {num_entries}. " + "Buffer fields cannot have variable index shape." + ) + + buffer_len = field.spatial_shape[0] + buf_shape: list[PsSymbol | PsConstant] + + if isinstance(buffer_len, TypedSymbol): + idx_type = self._normalize_type(buffer_len) + len_symb = self.get_symbol(buffer_len.name, idx_type) + len_symb.add_property(FieldShape(field, 0)) + buf_shape = [len_symb, PsConstant(num_entries, idx_type)] + else: + idx_type = DEFAULTS.index_dtype + buf_shape = [ + PsConstant(buffer_len, idx_type), + PsConstant(num_entries, idx_type), + ] + + buf_strides = [PsConstant(num_entries, idx_type), PsConstant(1, idx_type)] + + base_ptr = self.get_symbol( + DEFAULTS.field_pointer_name(field.name), + PsPointerType(field.dtype, restrict=True), + ) + + return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index b5c04f1bd029ac8fa62d2efaf3f0983b7cb858b5..bdc8f11336886dade665e475a023ac49a7595eba 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -14,7 +14,7 @@ from ...sympyextensions import ( ConditionalFieldAccess, ) from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType -from ...sympyextensions.pointers import AddressOf +from ...sympyextensions.pointers import AddressOf, mem_acc from ...field import Field, FieldType from .context import KernelCreationContext @@ -28,7 +28,7 @@ from ..ast.structural import ( PsSymbolExpr, ) from ..ast.expressions import ( - PsArrayAccess, + PsBufferAcc, PsArrayInitList, PsBitwiseAnd, PsBitwiseOr, @@ -43,7 +43,7 @@ from ..ast.expressions import ( PsLookup, PsRightShift, PsSubscript, - PsVectorArrayAccess, + PsVectorMemAcc, PsTernary, PsRel, PsEq, @@ -55,10 +55,11 @@ from ..ast.expressions import ( PsAnd, PsOr, PsNot, + PsMemAcc ) from ..constants import PsConstant -from ...types import PsStructType, PsType +from ...types import PsNumericType, PsStructType, PsType from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions @@ -157,7 +158,7 @@ class FreezeExpressions: if isinstance(lhs, PsSymbolExpr): return PsDeclaration(lhs, rhs) - elif isinstance(lhs, (PsArrayAccess, PsLookup, PsVectorArrayAccess)): # todo + elif isinstance(lhs, (PsBufferAcc, PsLookup, PsVectorMemAcc)): # todo return PsAssignment(lhs, rhs) else: raise FreezeError( @@ -191,27 +192,19 @@ class FreezeExpressions: def map_Add(self, expr: sp.Add) -> PsExpression: # TODO: think about numerically sensible ways of freezing sums and products - signs: list[int] = [] - for summand in expr.args: - if summand.is_negative: - signs.append(-1) - elif isinstance(summand, sp.Mul) and any( - factor.is_negative for factor in summand.args - ): - signs.append(-1) - else: - signs.append(1) frozen_expr = self.visit_expr(expr.args[0]) - for sign, arg in zip(signs[1:], expr.args[1:]): - if sign == -1: - arg = -arg + for summand in expr.args[1:]: + if isinstance(summand, sp.Mul) and any( + factor == -1 for factor in summand.args + ): + summand = -summand op = sub else: op = add - frozen_expr = op(frozen_expr, self.visit_expr(arg)) + frozen_expr = op(frozen_expr, self.visit_expr(summand)) return frozen_expr @@ -272,7 +265,7 @@ class FreezeExpressions: def map_TypedSymbol(self, expr: TypedSymbol): dtype = expr.dtype - + match dtype: case DynamicType.NUMERIC_TYPE: dtype = self._ctx.default_dtype @@ -283,20 +276,40 @@ class FreezeExpressions: return PsSymbolExpr(symb) def map_Tuple(self, expr: sp.Tuple) -> PsArrayInitList: + if not expr: + raise FreezeError("Cannot translate an empty tuple.") + items = [self.visit_expr(item) for item in expr] - return PsArrayInitList(items) + + if any(isinstance(i, PsArrayInitList) for i in items): + # base case: have nested arrays + if not all(isinstance(i, PsArrayInitList) for i in items): + raise FreezeError( + f"Cannot translate nested arrays of non-uniform shape: {expr}" + ) + + subarrays = cast(list[PsArrayInitList], items) + shape_tail = subarrays[0].shape + + if not all(s.shape == shape_tail for s in subarrays[1:]): + raise FreezeError( + f"Cannot translate nested arrays of non-uniform shape: {expr}" + ) + + return PsArrayInitList([s.items_grid for s in subarrays]) # type: ignore + else: + # base case: no nested arrays + return PsArrayInitList(items) def map_Indexed(self, expr: sp.Indexed) -> PsSubscript: assert isinstance(expr.base, sp.IndexedBase) base = self.visit_expr(expr.base.label) - subscript = PsSubscript(base, self.visit_expr(expr.indices[0])) - for idx in expr.indices[1:]: - subscript = PsSubscript(subscript, self.visit_expr(idx)) - return subscript + indices = [self.visit_expr(i) for i in expr.indices] + return PsSubscript(base, indices) def map_Access(self, access: Field.Access): field = access.field - array = self._ctx.get_array(field) + array = self._ctx.get_buffer(field) ptr = array.base_pointer offsets: list[PsExpression] = [ @@ -350,18 +363,11 @@ class FreezeExpressions: # For canonical representation, there must always be at least one index dimension indices = [PsExpression.make(PsConstant(0))] - summands = tuple( - idx * PsExpression.make(stride) - for idx, stride in zip(offsets + indices, array.strides, strict=True) - ) - - index = summands[0] if len(summands) == 1 else reduce(add, summands) - if struct_member_name is not None: # Produce a Lookup here, don't check yet if the member name is valid. That's the typifier's job. - return PsLookup(PsArrayAccess(ptr, index), struct_member_name) + return PsLookup(PsBufferAcc(ptr, offsets + indices), struct_member_name) else: - return PsArrayAccess(ptr, index) + return PsBufferAcc(ptr, offsets + indices) def map_ConditionalFieldAccess(self, acc: ConditionalFieldAccess): facc = self.visit_expr(acc.access) @@ -422,14 +428,23 @@ class FreezeExpressions: return PsBitwiseOr(*args) case integer_functions.int_power_of_2(): return PsLeftShift(PsExpression.make(PsConstant(1)), args[0]) - # TODO: what exactly are the semantics? - # case integer_functions.modulo_floor(): - # case integer_functions.div_floor() - # TODO: requires if *expression* - # case integer_functions.modulo_ceil(): - # case integer_functions.div_ceil(): + case integer_functions.round_to_multiple_towards_zero(): + return PsIntDiv(args[0], args[1]) * args[1] + case integer_functions.ceil_to_multiple(): + return ( + PsIntDiv( + args[0] + args[1] - PsExpression.make(PsConstant(1)), args[1] + ) + * args[1] + ) + case integer_functions.div_ceil(): + return PsIntDiv( + args[0] + args[1] - PsExpression.make(PsConstant(1)), args[1] + ) case AddressOf(): return PsAddressOf(*args) + case mem_acc(): + return PsMemAcc(*args) case _: raise FreezeError(f"Unsupported function: {func}") @@ -469,7 +484,7 @@ class FreezeExpressions: ] return cast(PsCall, args[0]) - def map_CastFunc(self, cast_expr: CastFunc) -> PsCast: + def map_CastFunc(self, cast_expr: CastFunc) -> PsCast | PsConstantExpr: dtype: PsType match cast_expr.dtype: case DynamicType.NUMERIC_TYPE: @@ -479,7 +494,19 @@ class FreezeExpressions: case other if isinstance(other, PsType): dtype = other - return PsCast(dtype, self.visit_expr(cast_expr.expr)) + arg = self.visit_expr(cast_expr.expr) + if ( + isinstance(arg, PsConstantExpr) + and arg.constant.dtype is None + and isinstance(dtype, PsNumericType) + ): + # As of now, the typifier can not infer the type of a bare constant. + # However, untyped constants may not appear in ASTs from which + # kernel functions are generated. Therefore, we annotate constants + # instead of casting them. + return PsConstantExpr(arg.constant.interpret_as(dtype)) + else: + return PsCast(dtype, arg) def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel: arg1, arg2 = [self.visit_expr(arg) for arg in rel.args] diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 8175fffed9782d9271262b7bc220e4dcdc208705..bae0328e4348836463f2d6831fc1905855354548 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -9,10 +9,9 @@ from ...defaults import DEFAULTS from ...simp import AssignmentCollection from ...field import Field, FieldType -from ..symbols import PsSymbol +from ..memory import PsSymbol, PsBuffer from ..constants import PsConstant from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem -from ..arrays import PsLinearizedArray from ..ast.util import failing_cast from ...types import PsStructType, constify from ..exceptions import PsInputError, KernelConstraintsError @@ -74,7 +73,7 @@ class FullIterationSpace(IterationSpace): ) -> FullIterationSpace: """Create an iteration space over an archetype field with ghost layers.""" - archetype_array = ctx.get_array(archetype_field) + archetype_array = ctx.get_buffer(archetype_field) dim = archetype_field.spatial_dimensions counters = [ @@ -142,7 +141,7 @@ class FullIterationSpace(IterationSpace): archetype_size: tuple[PsSymbol | PsConstant | None, ...] if archetype_field is not None: - archetype_array = ctx.get_array(archetype_field) + archetype_array = ctx.get_buffer(archetype_field) if archetype_field.spatial_dimensions != dim: raise ValueError( @@ -281,7 +280,7 @@ class SparseIterationSpace(IterationSpace): def __init__( self, spatial_indices: Sequence[PsSymbol], - index_list: PsLinearizedArray, + index_list: PsBuffer, coordinate_members: Sequence[PsStructType.Member], sparse_counter: PsSymbol, ): @@ -291,7 +290,7 @@ class SparseIterationSpace(IterationSpace): self._sparse_counter = sparse_counter @property - def index_list(self) -> PsLinearizedArray: + def index_list(self) -> PsBuffer: return self._index_list @property @@ -365,7 +364,7 @@ def create_sparse_iteration_space( # Determine index field if index_field is not None: - idx_arr = ctx.get_array(index_field) + idx_arr = ctx.get_buffer(index_field) idx_struct_type: PsStructType = failing_cast(PsStructType, idx_arr.element_type) for coord in coord_members: diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index fc085e2be99f61204cde92438811a3b4e41c8bf7..c8fad68f106dea78126be9d9ada51e2c57180cd2 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TypeVar +from typing import TypeVar, Callable from .context import KernelCreationContext from ...types import ( @@ -23,10 +23,11 @@ from ..ast.structural import ( PsExpression, PsAssignment, PsDeclaration, + PsStatement, PsEmptyLeafMixIn, ) from ..ast.expressions import ( - PsArrayAccess, + PsBufferAcc, PsArrayInitList, PsBinOp, PsIntOpTrait, @@ -35,11 +36,11 @@ from ..ast.expressions import ( PsCall, PsTernary, PsCast, - PsDeref, PsAddressOf, PsConstantExpr, PsLookup, PsSubscript, + PsMemAcc, PsSymbolExpr, PsLiteralExpr, PsRel, @@ -47,6 +48,7 @@ from ..ast.expressions import ( PsNot, ) from ..functions import PsMathFunction, CFunction +from ..ast.util import determine_memory_object __all__ = ["Typifier"] @@ -57,38 +59,40 @@ class TypificationError(Exception): NodeT = TypeVar("NodeT", bound=PsAstNode) +ResolutionHook = Callable[[PsType], None] + class TypeContext: """Typing context, with support for type inference and checking. Instances of this class are used to propagate and check data types across expression subtrees - of the AST. Each type context has: - - - A target type `target_type`, which shall be applied to all expressions it covers - - A set of restrictions on the target type: - - `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides - - Additional restrictions may be added in the future. + of the AST. Each type context has a target type `target_type`, which shall be applied to all expressions it covers """ def __init__( self, target_type: PsType | None = None, - require_nonconst: bool = False, ): - self._require_nonconst = require_nonconst self._deferred_exprs: list[PsExpression] = [] - self._target_type = ( - self._fix_constness(target_type) if target_type is not None else None - ) + self._target_type = deconstify(target_type) if target_type is not None else None + + self._hooks: list[ResolutionHook] = [] @property def target_type(self) -> PsType | None: return self._target_type - @property - def require_nonconst(self) -> bool: - return self._require_nonconst + def add_hook(self, hook: ResolutionHook): + """Adds a resolution hook to this context. + + The hook will be called with the context's target type as soon as it becomes known, + which might be immediately. + """ + if self._target_type is None: + self._hooks.append(hook) + else: + hook(self._target_type) def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None): """Applies the given ``dtype`` to this type context, and optionally to the given expression. @@ -102,7 +106,7 @@ class TypeContext: and will be replaced by it. """ - dtype = self._fix_constness(dtype) + dtype = deconstify(dtype) if self._target_type is not None and dtype != self._target_type: raise TypificationError( @@ -134,6 +138,12 @@ class TypeContext: self._apply_target_type(expr) def _propagate_target_type(self): + assert self._target_type is not None + + for hook in self._hooks: + hook(self._target_type) + self._hooks = [] + for expr in self._deferred_exprs: self._apply_target_type(expr) self._deferred_exprs = [] @@ -211,30 +221,10 @@ class TypeContext: def _compatible(self, dtype: PsType): """Checks whether the given data type is compatible with the context's target type. - - If the target type is ``const``, they must be equal up to const qualification; - if the target type is not ``const``, `dtype` must match it exactly. + The two must match except for constness. """ assert self._target_type is not None - if self._target_type.const: - return constify(dtype) == self._target_type - else: - return dtype == self._target_type - - def _fix_constness(self, dtype: PsType, expr: PsExpression | None = None): - if self._require_nonconst: - if dtype.const: - if expr is None: - raise TypificationError( - f"Type mismatch: Encountered {dtype} in non-constant context." - ) - else: - raise TypificationError( - f"Type mismatch at expression {expr}: Encountered {dtype} in non-constant context." - ) - return dtype - else: - return constify(dtype) + return deconstify(dtype) == self._target_type class Typifier: @@ -276,11 +266,7 @@ class Typifier: **Typing of symbol expressions** - Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but - not necessarily their const-qualification. - A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type, - and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`, - but not vice versa. + Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types. """ def __init__(self, ctx: KernelCreationContext): @@ -316,7 +302,52 @@ class Typifier: for s in statements: self.visit(s) - case PsDeclaration(lhs, rhs): + case PsStatement(expr): + tc = TypeContext() + self.visit_expr(expr, tc) + if tc.target_type is None: + tc.apply_dtype(self._ctx.default_dtype) + + case PsDeclaration(lhs, rhs) if isinstance(rhs, PsArrayInitList): + # Special treatment for array declarations + assert isinstance(lhs, PsSymbolExpr) + + decl_tc = TypeContext() + items_tc = TypeContext() + + if (lhs_type := lhs.symbol.dtype) is not None: + if not isinstance(lhs_type, PsArrayType): + raise TypificationError( + f"Illegal LHS type in array declaration: {lhs_type}" + ) + + if lhs_type.shape != rhs.shape: + raise TypificationError( + f"Incompatible shapes in declaration of array symbol {lhs.symbol}.\n" + f" Symbol shape: {lhs_type.shape}\n" + f" Array shape: {rhs.shape}" + ) + + items_tc.apply_dtype(lhs_type.base_type) + decl_tc.apply_dtype(lhs_type, lhs) + else: + decl_tc.infer_dtype(lhs) + + for item in rhs.items: + self.visit_expr(item, items_tc) + + if items_tc.target_type is None: + items_tc.apply_dtype(self._ctx.default_dtype) + + if decl_tc.target_type is None: + assert items_tc.target_type is not None + decl_tc.apply_dtype( + PsArrayType(items_tc.target_type, rhs.shape), rhs + ) + else: + decl_tc.infer_dtype(rhs) + + case PsDeclaration(lhs, rhs) | PsAssignment(lhs, rhs): # Only if the LHS is an untyped symbol, infer its type from the RHS infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None @@ -333,27 +364,11 @@ class Typifier: if infer_lhs and tc.target_type is None: # no type has been inferred -> use the default dtype tc.apply_dtype(self._ctx.default_dtype) - - case PsAssignment(lhs, rhs): - infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None - - tc_lhs = TypeContext(require_nonconst=True) - - if infer_lhs: - tc_lhs.infer_dtype(lhs) - else: - self.visit_expr(lhs, tc_lhs) - assert tc_lhs.target_type is not None - - tc_rhs = TypeContext(target_type=tc_lhs.target_type) - self.visit_expr(rhs, tc_rhs) - - if infer_lhs: - if tc_rhs.target_type is None: - tc_rhs.apply_dtype(self._ctx.default_dtype) - - assert tc_rhs.target_type is not None - tc_lhs.apply_dtype(deconstify(tc_rhs.target_type)) + elif not isinstance(node, PsDeclaration): + # check mutability of LHS + _, lhs_const = determine_memory_object(lhs) + if lhs_const: + raise TypificationError(f"Cannot assign to immutable LHS {lhs}") case PsConditional(cond, branch_true, branch_false): cond_tc = TypeContext(PsBoolType()) @@ -407,51 +422,59 @@ class Typifier: case PsLiteralExpr(lit): tc.apply_dtype(lit.dtype, expr) - case PsArrayAccess(bptr, idx): - tc.apply_dtype(bptr.array.element_type, expr) + case PsBufferAcc(_, indices): + tc.apply_dtype(expr.buffer.element_type, expr) + for idx in indices: + self._handle_idx(idx) - index_tc = TypeContext() - self.visit_expr(idx, index_tc) - if index_tc.target_type is None: - index_tc.apply_dtype(self._ctx.index_dtype, idx) - elif not isinstance(index_tc.target_type, PsIntegerType): + case PsMemAcc(ptr, offset): + ptr_tc = TypeContext() + self.visit_expr(ptr, ptr_tc) + + if not isinstance(ptr_tc.target_type, PsPointerType): raise TypificationError( - f"Array index is not of integer type: {idx} has type {index_tc.target_type}" + f"Type of pointer argument to memory access was not a pointer type: {ptr_tc.target_type}" ) - case PsSubscript(arr, idx): - arr_tc = TypeContext() - self.visit_expr(arr, arr_tc) + tc.apply_dtype(ptr_tc.target_type.base_type, expr) + self._handle_idx(offset) - if not isinstance(arr_tc.target_type, PsDereferencableType): - raise TypificationError( - "Type of subscript base is not subscriptable." - ) + case PsSubscript(arr, indices): + if isinstance(arr, PsArrayInitList): + shape = arr.shape - tc.apply_dtype(arr_tc.target_type.base_type, expr) + # extend outer context over the init-list entries + for item in arr.items: + self.visit_expr(item, tc) - index_tc = TypeContext() - self.visit_expr(idx, index_tc) - if index_tc.target_type is None: - index_tc.apply_dtype(self._ctx.index_dtype, idx) - elif not isinstance(index_tc.target_type, PsIntegerType): - raise TypificationError( - f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}" - ) + # learn the array type from the items + def arr_hook(element_type: PsType): + arr.dtype = PsArrayType(element_type, arr.shape) - case PsDeref(ptr): - ptr_tc = TypeContext() - self.visit_expr(ptr, ptr_tc) + tc.add_hook(arr_hook) + else: + # type of array has to be known + arr_tc = TypeContext() + self.visit_expr(arr, arr_tc) - if not isinstance(ptr_tc.target_type, PsDereferencableType): + if not isinstance(arr_tc.target_type, PsArrayType): + raise TypificationError( + f"Type of array argument to subscript was not an array type: {arr_tc.target_type}" + ) + + tc.apply_dtype(arr_tc.target_type.base_type, expr) + shape = arr_tc.target_type.shape + + if len(indices) != len(shape): raise TypificationError( - "Type of argument to a Deref is not dereferencable" + f"Invalid number of indices to {len(shape)}-dimensional array: {len(indices)}" ) - tc.apply_dtype(ptr_tc.target_type.base_type, expr) + for idx in indices: + self._handle_idx(idx) case PsAddressOf(arg): - if not isinstance(arg, (PsSymbolExpr, PsSubscript, PsDeref, PsLookup)): + if not isinstance(arg, (PsSymbolExpr, PsSubscript, PsMemAcc, PsBufferAcc, PsLookup)): raise TypificationError( f"Illegal expression below AddressOf operator: {arg}" ) @@ -468,8 +491,8 @@ class Typifier: match arg: case PsSymbolExpr(s): pointed_to_type = s.get_dtype() - case PsSubscript(arr, _) | PsDeref(arr): - arr_type = arr.get_dtype() + case PsSubscript(ptr, _) | PsMemAcc(ptr, _) | PsBufferAcc(ptr, _): + arr_type = ptr.get_dtype() assert isinstance(arr_type, PsDereferencableType) pointed_to_type = arr_type.base_type case PsLookup(aggr, member_name): @@ -491,7 +514,7 @@ class Typifier: case PsLookup(aggr, member_name): # Members of a struct type inherit the struct type's `const` qualifier - aggr_tc = TypeContext(None, require_nonconst=tc.require_nonconst) + aggr_tc = TypeContext() self.visit_expr(aggr, aggr_tc) aggr_type = aggr_tc.target_type @@ -566,36 +589,35 @@ class Typifier: f"Don't know how to typify calls to {function}" ) - case PsArrayInitList(items): - items_tc = TypeContext() - for item in items: - self.visit_expr(item, items_tc) - - if items_tc.target_type is None: - if tc.target_type is None: - raise TypificationError(f"Unable to infer type of array {expr}") - elif not isinstance(tc.target_type, PsArrayType): - raise TypificationError( - f"Cannot apply type {tc.target_type} to an array initializer." - ) - elif ( - tc.target_type.length is not None - and tc.target_type.length != len(items) - ): - raise TypificationError( - "Array size mismatch: Cannot typify initializer list with " - f"{len(items)} items as {tc.target_type}" - ) - else: - items_tc.apply_dtype(tc.target_type.base_type) - tc.infer_dtype(expr) - else: - arr_type = PsArrayType(items_tc.target_type, len(items)) - tc.apply_dtype(arr_type, expr) + case PsArrayInitList(_): + raise TypificationError( + "Unable to typify array initializer in isolation.\n" + f" Array: {expr}" + ) case PsCast(dtype, arg): - self.visit_expr(arg, TypeContext()) + arg_tc = TypeContext() + self.visit_expr(arg, arg_tc) + + if arg_tc.target_type is None: + raise TypificationError( + f"Unable to determine type of argument to Cast: {arg}" + ) + tc.apply_dtype(dtype, expr) case _: raise NotImplementedError(f"Can't typify {expr}") + + def _handle_idx(self, idx: PsExpression): + index_tc = TypeContext() + self.visit_expr(idx, index_tc) + + if index_tc.target_type is None: + index_tc.apply_dtype(self._ctx.index_dtype, idx) + elif not isinstance(index_tc.target_type, PsIntegerType): + raise TypificationError( + f"Invalid data type in index expression.\n" + f" Expression: {idx}\n" + f" Type: {index_tc.target_type}" + ) diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index a3213350e2c21b91a9e3c41f2e582992f293d7fb..8868179307fa8a76bec5049ea8cf05ac3a4b46e0 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -1,15 +1,21 @@ from __future__ import annotations from warnings import warn -from abc import ABC from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING +from itertools import chain from .._deprecation import _deprecated from .ast.structural import PsBlock from .ast.analysis import collect_required_headers, collect_undefined_symbols -from .arrays import PsArrayShapeSymbol, PsArrayStrideSymbol, PsArrayBasePointer -from .symbols import PsSymbol +from .memory import PsSymbol +from .properties import ( + PsSymbolProperty, + _FieldProperty, + FieldShape, + FieldStride, + FieldBasePtr, +) from .kernelcreation.context import KernelCreationContext from .platforms import Platform, GpuThreadsRange @@ -25,11 +31,29 @@ if TYPE_CHECKING: class KernelParameter: - __match_args__ = ("name", "dtype") + """Parameter to a `KernelFunction`.""" - def __init__(self, name: str, dtype: PsType): + __match_args__ = ("name", "dtype", "properties") + + def __init__( + self, name: str, dtype: PsType, properties: Iterable[PsSymbolProperty] = () + ): self._name = name self._dtype = dtype + self._properties: frozenset[PsSymbolProperty] = ( + frozenset(properties) if properties is not None else frozenset() + ) + self._fields: tuple[Field, ...] = tuple( + sorted( + set( + p.field # type: ignore + for p in filter( + lambda p: isinstance(p, _FieldProperty), self._properties + ) + ), + key=lambda f: f.name + ) + ) @property def name(self): @@ -40,8 +64,9 @@ class KernelParameter: return self._dtype def _hashable_contents(self): - return (self._name, self._dtype) + return (self._name, self._dtype, self._properties) + # TODO: Need? def __hash__(self) -> int: return hash(self._hashable_contents()) @@ -64,110 +89,63 @@ class KernelParameter: def symbol(self) -> TypedSymbol: return TypedSymbol(self.name, self.dtype) + @property + def fields(self) -> Sequence[Field]: + """Set of fields associated with this parameter.""" + return self._fields + + def get_properties( + self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...] + ) -> set[PsSymbolProperty]: + """Retrieve all properties of the given type(s) attached to this parameter""" + return set(filter(lambda p: isinstance(p, prop_type), self._properties)) + + @property + def properties(self) -> frozenset[PsSymbolProperty]: + return self._properties + @property def is_field_parameter(self) -> bool: - warn( - "`is_field_parameter` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldParameter)` instead.", - DeprecationWarning, - ) - return isinstance(self, FieldParameter) + return bool(self._fields) + + # Deprecated legacy properties + # These are kept mostly for the legacy waLBerla code generation system @property def is_field_pointer(self) -> bool: warn( "`is_field_pointer` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldPointerParam)` instead.", + "Use `param.get_properties(FieldBasePtr)` instead.", DeprecationWarning, ) - return isinstance(self, FieldPointerParam) + return bool(self.get_properties(FieldBasePtr)) @property def is_field_stride(self) -> bool: warn( "`is_field_stride` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldStrideParam)` instead.", + "Use `param.get_properties(FieldStride)` instead.", DeprecationWarning, ) - return isinstance(self, FieldStrideParam) + return bool(self.get_properties(FieldStride)) @property def is_field_shape(self) -> bool: warn( "`is_field_shape` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldShapeParam)` instead.", - DeprecationWarning, - ) - return isinstance(self, FieldShapeParam) - - -class FieldParameter(KernelParameter, ABC): - __match_args__ = KernelParameter.__match_args__ + ("field",) - - def __init__(self, name: str, dtype: PsType, field: Field): - super().__init__(name, dtype) - self._field = field - - @property - def field(self): - return self._field - - @property - def fields(self): - warn( - "`fields` is deprecated and will be removed in a future version of pystencils. " - "In pystencils >= 2.0, field parameters are only associated with a single field." - "Use the `field` property instead.", + "Use `param.get_properties(FieldShape)` instead.", DeprecationWarning, ) - return [self._field] + return bool(self.get_properties(FieldShape)) @property def field_name(self) -> str: warn( "`field_name` is deprecated and will be removed in a future version of pystencils. " - "Use `field.name` instead.", + "Use `param.fields[0].name` instead.", DeprecationWarning, ) - return self._field.name - - def _hashable_contents(self): - return super()._hashable_contents() + (self._field,) - - -class FieldShapeParam(FieldParameter): - __match_args__ = FieldParameter.__match_args__ + ("coordinate",) - - def __init__(self, name: str, dtype: PsType, field: Field, coordinate: int): - super().__init__(name, dtype, field) - self._coordinate = coordinate - - @property - def coordinate(self): - return self._coordinate - - def _hashable_contents(self): - return super()._hashable_contents() + (self._coordinate,) - - -class FieldStrideParam(FieldParameter): - __match_args__ = FieldParameter.__match_args__ + ("coordinate",) - - def __init__(self, name: str, dtype: PsType, field: Field, coordinate: int): - super().__init__(name, dtype, field) - self._coordinate = coordinate - - @property - def coordinate(self): - return self._coordinate - - def _hashable_contents(self): - return super()._hashable_contents() + (self._coordinate,) - - -class FieldPointerParam(FieldParameter): - def __init__(self, name: str, dtype: PsType, field: Field): - super().__init__(name, dtype, field) + return self._fields[0].name class KernelFunction: @@ -236,7 +214,7 @@ class KernelFunction: return self.parameters def get_fields(self) -> set[Field]: - return set(p.field for p in self._params if isinstance(p, FieldParameter)) + return set(chain.from_iterable(p.fields for p in self._params)) @property def fields_accessed(self) -> set[Field]: @@ -333,19 +311,19 @@ def create_gpu_kernel_function( def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]): params: list[KernelParameter] = [] + + from pystencils.backend.memory import BufferBasePtr + for symb in symbols: - match symb: - case PsArrayShapeSymbol(name, _, arr, coord): - field = ctx.find_field(arr.name) - params.append(FieldShapeParam(name, symb.get_dtype(), field, coord)) - case PsArrayStrideSymbol(name, _, arr, coord): - field = ctx.find_field(arr.name) - params.append(FieldStrideParam(name, symb.get_dtype(), field, coord)) - case PsArrayBasePointer(name, _, arr): - field = ctx.find_field(arr.name) - params.append(FieldPointerParam(name, symb.get_dtype(), field)) - case PsSymbol(name, _): - params.append(KernelParameter(name, symb.get_dtype())) + props: set[PsSymbolProperty] = set() + for prop in symb.properties: + match prop: + case FieldShape() | FieldStride(): + props.add(prop) + case BufferBasePtr(buf): + field = ctx.find_field(buf.name) + props.add(FieldBasePtr(field)) + params.append(KernelParameter(symb.name, symb.get_dtype(), props)) params.sort(key=lambda p: p.name) return params diff --git a/src/pystencils/backend/literals.py b/src/pystencils/backend/literals.py index dc254da0e340d518929d6eecb483defcdffbe185..976e6b2030d2350a0c7105f8a3a17cfcccf393fd 100644 --- a/src/pystencils/backend/literals.py +++ b/src/pystencils/backend/literals.py @@ -6,7 +6,7 @@ class PsLiteral: """Representation of literal code. Instances of this class represent code literals inside the AST. - These literals are not to be confused with C literals; the name `Literal` refers to the fact that + These literals are not to be confused with C literals; the name 'Literal' refers to the fact that the code generator takes them "literally", printing them as they are. Each literal has to be annotated with a type, and is considered constant within the scope of a kernel. diff --git a/src/pystencils/backend/memory.py b/src/pystencils/backend/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..9b72a4e4337f1de152291b3287cffb612999a786 --- /dev/null +++ b/src/pystencils/backend/memory.py @@ -0,0 +1,198 @@ +from __future__ import annotations +from typing import Sequence +from itertools import chain +from dataclasses import dataclass + +from ..types import PsType, PsTypeError, deconstify, PsIntegerType, PsPointerType +from .exceptions import PsInternalCompilerError +from .constants import PsConstant +from .properties import PsSymbolProperty, UniqueSymbolProperty + + +class PsSymbol: + """A mutable symbol with name and data type. + + Do not create objects of this class directly unless you know what you are doing; + instead obtain them from a `KernelCreationContext` through `KernelCreationContext.get_symbol`. + This way, the context can keep track of all symbols used in the translation run, + and uniqueness of symbols is ensured. + """ + + __match_args__ = ("name", "dtype") + + def __init__(self, name: str, dtype: PsType | None = None): + self._name = name + self._dtype = dtype + self._properties: set[PsSymbolProperty] = set() + + @property + def name(self) -> str: + return self._name + + @property + def dtype(self) -> PsType | None: + return self._dtype + + @dtype.setter + def dtype(self, value: PsType): + self._dtype = value + + def apply_dtype(self, dtype: PsType): + """Apply the given data type to this symbol, + raising a TypeError if it conflicts with a previously set data type.""" + + if self._dtype is not None and self._dtype != dtype: + raise PsTypeError( + f"Incompatible symbol data types: {self._dtype} and {dtype}" + ) + + self._dtype = dtype + + def get_dtype(self) -> PsType: + if self._dtype is None: + raise PsInternalCompilerError( + f"Symbol {self.name} had no type assigned yet" + ) + return self._dtype + + @property + def properties(self) -> frozenset[PsSymbolProperty]: + """Set of properties attached to this symbol""" + return frozenset(self._properties) + + def get_properties( + self, prop_type: type[PsSymbolProperty] + ) -> set[PsSymbolProperty]: + """Retrieve all properties of the given type attached to this symbol""" + return set(filter(lambda p: isinstance(p, prop_type), self._properties)) + + def add_property(self, property: PsSymbolProperty): + """Attach a property to this symbol""" + if isinstance(property, UniqueSymbolProperty) and not self.get_properties( + type(property) + ) <= {property}: + raise ValueError( + f"Cannot add second instance of unique property {type(property)} to symbol {self._name}." + ) + + self._properties.add(property) + + def remove_property(self, property: PsSymbolProperty): + """Remove a property from this symbol. Does nothing if the property is not attached.""" + self._properties.discard(property) + + def __str__(self) -> str: + dtype_str = "<untyped>" if self._dtype is None else str(self._dtype) + return f"{self._name}: {dtype_str}" + + def __repr__(self) -> str: + return f"PsSymbol({repr(self._name)}, {repr(self._dtype)})" + + +@dataclass(frozen=True) +class BufferBasePtr(UniqueSymbolProperty): + """Symbol acts as a base pointer to a buffer.""" + + buffer: PsBuffer + + +class PsBuffer: + """N-dimensional contiguous linearized buffer in heap memory. + + `PsBuffer` models the memory buffers underlying the `Field` class + to the backend. Each buffer represents a contiguous block of memory + that is non-aliased and disjoint from all other buffers. + + Buffer shape and stride information are given either as constants or as symbols. + All indexing expressions must have the same data type, which will be selected as the buffer's + `index_dtype`. + + Each buffer has at least one base pointer, which can be retrieved via the `PsBuffer.base_pointer` + property. + """ + + def __init__( + self, + name: str, + element_type: PsType, + base_ptr: PsSymbol, + shape: Sequence[PsSymbol | PsConstant], + strides: Sequence[PsSymbol | PsConstant], + ): + bptr_type = base_ptr.get_dtype() + + if not isinstance(bptr_type, PsPointerType): + raise ValueError( + f"Type of buffer base pointer {base_ptr} was not a pointer type: {bptr_type}" + ) + + if bptr_type.base_type != element_type: + raise ValueError( + f"Base type of primary buffer base pointer {base_ptr} " + f"did not equal buffer element type {element_type}." + ) + + if len(shape) != len(strides): + raise ValueError("Buffer shape and stride tuples must have the same length") + + idx_types: set[PsType] = set( + deconstify(s.get_dtype()) for s in chain(shape, strides) + ) + if len(idx_types) > 1: + raise ValueError( + f"Conflicting data types in indexing symbols to buffer {name}: {idx_types}" + ) + + idx_dtype = idx_types.pop() + if not isinstance(idx_dtype, PsIntegerType): + raise ValueError( + f"Invalid index data type for buffer {name}: {idx_dtype}. Must be an integer type." + ) + + self._name = name + self._element_type = element_type + self._index_dtype = idx_dtype + + self._shape = tuple(shape) + self._strides = tuple(strides) + + base_ptr.add_property(BufferBasePtr(self)) + self._base_ptr = base_ptr + + @property + def name(self): + """The buffer's name""" + return self._name + + @property + def base_pointer(self) -> PsSymbol: + """Primary base pointer""" + return self._base_ptr + + @property + def shape(self) -> tuple[PsSymbol | PsConstant, ...]: + """Buffer shape symbols and/or constants""" + return self._shape + + @property + def strides(self) -> tuple[PsSymbol | PsConstant, ...]: + """Buffer stride symbols and/or constants""" + return self._strides + + @property + def dim(self) -> int: + """Dimensionality of this buffer""" + return len(self._shape) + + @property + def index_type(self) -> PsIntegerType: + """Index data type of this buffer; i.e. data type of its shape and stride symbols""" + return self._index_dtype + + @property + def element_type(self) -> PsType: + """Element type of this buffer""" + return self._element_type + + def __repr__(self) -> str: + return f"PsBuffer({self._name}: {self.element_type}[{len(self.shape)}D])" diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index c89d2278881bc48cd81df44ab49adefda4fc99f3..323dcc5a9b306990a11f85037f305794c8729625 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -7,6 +7,7 @@ from ..kernelcreation import ( IterationSpace, FullIterationSpace, SparseIterationSpace, + AstFactory ) from ..kernelcreation.context import KernelCreationContext @@ -17,7 +18,7 @@ from ..ast.expressions import ( PsCast, PsCall, PsLookup, - PsArrayAccess, + PsBufferAcc, ) from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType, PsIeeeFloatType @@ -159,6 +160,7 @@ class CudaPlatform(GenericGpu): def _prepend_sparse_translation( self, body: PsBlock, ispace: SparseIterationSpace ) -> tuple[PsBlock, GpuThreadsRange]: + factory = AstFactory(self._ctx) ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) sparse_ctr = PsExpression.make(ispace.sparse_counter) @@ -171,9 +173,9 @@ class CudaPlatform(GenericGpu): PsDeclaration( PsExpression.make(ctr), PsLookup( - PsArrayAccess( + PsBufferAcc( ispace.index_list.base_pointer, - sparse_ctr, + (sparse_ctr, factory.parse_index(0)), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index a1505e672862264a50a8b035a83dc8dcdfb0769d..f8cae89fcb8f0f2310a83049c1f3453ba9329b39 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -21,8 +21,8 @@ from ..ast.structural import PsDeclaration, PsLoop, PsBlock from ..ast.expressions import ( PsSymbolExpr, PsExpression, - PsArrayAccess, - PsVectorArrayAccess, + PsBufferAcc, + PsVectorMemAcc, PsLookup, PsGe, PsLe, @@ -124,13 +124,15 @@ class GenericCpu(Platform): return PsBlock([loops]) def _create_sparse_loop(self, body: PsBlock, ispace: SparseIterationSpace): + factory = AstFactory(self._ctx) + mappings = [ PsDeclaration( PsSymbolExpr(ctr), PsLookup( - PsArrayAccess( + PsBufferAcc( ispace.index_list.base_pointer, - PsExpression.make(ispace.sparse_counter), + (PsExpression.make(ispace.sparse_counter), factory.parse_index(0)), ), coord.name, ), @@ -173,11 +175,11 @@ class GenericVectorCpu(GenericCpu, ABC): or raise an `MaterializationError` if not supported.""" @abstractmethod - def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression: + def vector_load(self, acc: PsVectorMemAcc) -> PsExpression: """Return an expression intrinsically performing a vector load, or raise an `MaterializationError` if not supported.""" @abstractmethod - def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression: + def vector_store(self, acc: PsVectorMemAcc, arg: PsExpression) -> PsExpression: """Return an expression intrinsically performing a vector store, or raise an `MaterializationError` if not supported.""" diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index 39ec09992765f388c31ec1cb003a4d5e7849fbca..ec5e7eda05d0417a764d26294206c6c0dcf7d02d 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -16,11 +16,11 @@ from ..ast.expressions import ( PsLe, PsTernary, PsLookup, - PsArrayAccess + PsBufferAcc ) from ..extensions.cpp import CppMethodCall -from ..kernelcreation.context import KernelCreationContext +from ..kernelcreation import KernelCreationContext, AstFactory from ..constants import PsConstant from .generic_gpu import GenericGpu, GpuThreadsRange from ..exceptions import MaterializationError @@ -121,7 +121,7 @@ class SyclPlatform(GenericGpu): for i, dim in enumerate(dimensions): # Slowest to fastest coord = PsExpression.make(PsConstant(i, self._ctx.index_dtype)) - work_item_idx = PsSubscript(id_symbol, coord) + work_item_idx = PsSubscript(id_symbol, (coord,)) dim.counter.dtype = constify(dim.counter.get_dtype()) work_item_idx.dtype = dim.counter.get_dtype() @@ -147,11 +147,13 @@ class SyclPlatform(GenericGpu): def _prepend_sparse_translation( self, body: PsBlock, ispace: SparseIterationSpace ) -> tuple[PsBlock, GpuThreadsRange]: + factory = AstFactory(self._ctx) + id_type = PsCustomType("sycl::id< 1 >", const=True) id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type)) zero = PsExpression.make(PsConstant(0, self._ctx.index_dtype)) - subscript = PsSubscript(id_symbol, zero) + subscript = PsSubscript(id_symbol, (zero,)) ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) subscript.dtype = ispace.sparse_counter.get_dtype() @@ -163,9 +165,9 @@ class SyclPlatform(GenericGpu): PsDeclaration( PsExpression.make(ctr), PsLookup( - PsArrayAccess( + PsBufferAcc( ispace.index_list.base_pointer, - sparse_ctr, + (sparse_ctr, factory.parse_index(0)), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index ccaf9fbe99f46ce4b0ecbb81c775c9f274678026..33838df08bcdb13094a96387fd3db565e4ba5932 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -5,9 +5,9 @@ from typing import Sequence from ..ast.expressions import ( PsExpression, - PsVectorArrayAccess, + PsVectorMemAcc, PsAddressOf, - PsSubscript, + PsMemAcc, ) from ..transformations.select_intrinsics import IntrinsicOps from ...types import PsCustomType, PsVectorType, PsPointerType @@ -141,20 +141,20 @@ class X86VectorCpu(GenericVectorCpu): func = _x86_op_intrin(self._vector_arch, op, vtype) return func(*args) - def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression: + def vector_load(self, acc: PsVectorMemAcc) -> PsExpression: if acc.stride == 1: load_func = _x86_packed_load(self._vector_arch, acc.dtype, False) return load_func( - PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)) + PsAddressOf(PsMemAcc(acc.pointer, acc.offset)) ) else: raise NotImplementedError("Gather loads not implemented yet.") - def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression: + def vector_store(self, acc: PsVectorMemAcc, arg: PsExpression) -> PsExpression: if acc.stride == 1: store_func = _x86_packed_store(self._vector_arch, acc.dtype, False) return store_func( - PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)), + PsAddressOf(PsMemAcc(acc.pointer, acc.offset)), arg, ) else: diff --git a/src/pystencils/backend/properties.py b/src/pystencils/backend/properties.py new file mode 100644 index 0000000000000000000000000000000000000000..d377fb3d35d99b59c4f364cc4d066b736bfd9140 --- /dev/null +++ b/src/pystencils/backend/properties.py @@ -0,0 +1,41 @@ +from __future__ import annotations +from dataclasses import dataclass + +from ..field import Field + + +@dataclass(frozen=True) +class PsSymbolProperty: + """Base class for symbol properties, which can be used to add additional information to symbols""" + + +@dataclass(frozen=True) +class UniqueSymbolProperty(PsSymbolProperty): + """Base class for unique properties, of which only one instance may be registered at a time.""" + + +@dataclass(frozen=True) +class FieldShape(PsSymbolProperty): + """Symbol acts as a shape parameter to a field.""" + + field: Field + coordinate: int + + +@dataclass(frozen=True) +class FieldStride(PsSymbolProperty): + """Symbol acts as a stride parameter to a field.""" + + field: Field + coordinate: int + + +@dataclass(frozen=True) +class FieldBasePtr(UniqueSymbolProperty): + """Symbol acts as a base pointer to a field.""" + + field: Field + + +FieldProperty = FieldShape | FieldStride | FieldBasePtr +_FieldProperty = (FieldShape, FieldStride, FieldBasePtr) diff --git a/src/pystencils/backend/symbols.py b/src/pystencils/backend/symbols.py deleted file mode 100644 index b007e3fcf4791e89bb34829f0c6e4b7d1dbbbd21..0000000000000000000000000000000000000000 --- a/src/pystencils/backend/symbols.py +++ /dev/null @@ -1,55 +0,0 @@ -from ..types import PsType, PsTypeError -from .exceptions import PsInternalCompilerError - - -class PsSymbol: - """A mutable symbol with name and data type. - - Do not create objects of this class directly unless you know what you are doing; - instead obtain them from a `KernelCreationContext` through `KernelCreationContext.get_symbol`. - This way, the context can keep track of all symbols used in the translation run, - and uniqueness of symbols is ensured. - """ - - __match_args__ = ("name", "dtype") - - def __init__(self, name: str, dtype: PsType | None = None): - self._name = name - self._dtype = dtype - - @property - def name(self) -> str: - return self._name - - @property - def dtype(self) -> PsType | None: - return self._dtype - - @dtype.setter - def dtype(self, value: PsType): - self._dtype = value - - def apply_dtype(self, dtype: PsType): - """Apply the given data type to this symbol, - raising a TypeError if it conflicts with a previously set data type.""" - - if self._dtype is not None and self._dtype != dtype: - raise PsTypeError( - f"Incompatible symbol data types: {self._dtype} and {dtype}" - ) - - self._dtype = dtype - - def get_dtype(self) -> PsType: - if self._dtype is None: - raise PsInternalCompilerError( - f"Symbol {self.name} had no type assigned yet" - ) - return self._dtype - - def __str__(self) -> str: - dtype_str = "<untyped>" if self._dtype is None else str(self._dtype) - return f"{self._name}: {dtype_str}" - - def __repr__(self) -> str: - return f"PsSymbol({self._name}, {self._dtype})" diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index 88ad9348f09685258d2aecb5fca66fcfe609173b..7375af618a438d9145e76fe2097d7176b7d2b2ea 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -69,7 +69,7 @@ Loop Reshaping Transformations Code Lowering and Materialization --------------------------------- -.. autoclass:: EraseAnonymousStructTypes +.. autoclass:: LowerToC :members: __call__ .. autoclass:: SelectFunctions @@ -84,7 +84,7 @@ from .eliminate_branches import EliminateBranches from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .reshape_loops import ReshapeLoops from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP -from .erase_anonymous_structs import EraseAnonymousStructTypes +from .lower_to_c import LowerToC from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics @@ -98,7 +98,7 @@ __all__ = [ "InsertPragmasAtLoops", "LoopPragma", "AddOpenMP", - "EraseAnonymousStructTypes", + "LowerToC", "SelectFunctions", "MaterializeVectorIntrinsics", ] diff --git a/src/pystencils/backend/transformations/canonical_clone.py b/src/pystencils/backend/transformations/canonical_clone.py index b21fd115f98645ff4c8dfb2dd3f72c252282fcf2..2cf9bcf0c8d95f16e9935ed1754f853f13516cc8 100644 --- a/src/pystencils/backend/transformations/canonical_clone.py +++ b/src/pystencils/backend/transformations/canonical_clone.py @@ -1,7 +1,7 @@ from typing import TypeVar, cast from ..kernelcreation import KernelCreationContext -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..exceptions import PsInternalCompilerError from ..ast import PsAstNode diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py index e55807ef4bd7726d173b6028997778538b66053b..f5b356432a56cc8c2a33eba6ad533947b9f2b2ad 100644 --- a/src/pystencils/backend/transformations/canonicalize_symbols.py +++ b/src/pystencils/backend/transformations/canonicalize_symbols.py @@ -1,5 +1,5 @@ from ..kernelcreation import KernelCreationContext -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..exceptions import PsInternalCompilerError from ..ast import PsAstNode diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index bd3b2bb5871b59dd1e0431ed3d5ee7d26c4ff7a6..222f4a378c3063cf58e1d14f714a5ccbdc524964 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -34,7 +34,7 @@ from ..ast.expressions import ( from ..ast.util import AstEqWrapper from ..constants import PsConstant -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..functions import PsMathFunction from ...types import ( PsIntegerType, diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py deleted file mode 100644 index 03d79a68972b974dd4f82a0f6fae55ef32d68395..0000000000000000000000000000000000000000 --- a/src/pystencils/backend/transformations/erase_anonymous_structs.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -from ..kernelcreation.context import KernelCreationContext - -from ..constants import PsConstant -from ..ast.structural import PsAstNode -from ..ast.expressions import ( - PsArrayAccess, - PsLookup, - PsExpression, - PsDeref, - PsAddressOf, - PsCast, -) -from ..kernelcreation import Typifier -from ..arrays import PsArrayBasePointer, TypeErasedBasePointer -from ...types import PsStructType, PsPointerType - - -class EraseAnonymousStructTypes: - """Lower anonymous struct arrays to a byte-array representation. - - For arrays whose element type is an anonymous struct, the struct type is erased from the base pointer, - making it a pointer to uint8_t. - Member lookups on accesses into these arrays are then transformed using type casts. - """ - - def __init__(self, ctx: KernelCreationContext) -> None: - self._ctx = ctx - - self._substitutions: dict[PsArrayBasePointer, TypeErasedBasePointer] = dict() - - def __call__(self, node: PsAstNode) -> PsAstNode: - self._substitutions = dict() - - # Check if AST traversal is even necessary - if not any( - (isinstance(arr.element_type, PsStructType) and arr.element_type.anonymous) - for arr in self._ctx.arrays - ): - return node - - node = self.visit(node) - - for old, new in self._substitutions.items(): - self._ctx.replace_symbol(old, new) - - return node - - def visit(self, node: PsAstNode) -> PsAstNode: - match node: - case PsLookup(): - # descend into expr - return self.handle_lookup(node) - case _: - node.children = [self.visit(c) for c in node.children] - - return node - - def handle_lookup(self, lookup: PsLookup) -> PsExpression: - aggr = lookup.aggregate - if not isinstance(aggr, PsArrayAccess): - return lookup - - arr = aggr.array - if ( - not isinstance(arr.element_type, PsStructType) - or not arr.element_type.anonymous - ): - return lookup - - struct_type = arr.element_type - struct_size = struct_type.itemsize - - bp = aggr.base_ptr - - # Need to keep track of base pointers already seen, since symbols must be unique - if bp not in self._substitutions: - type_erased_bp = TypeErasedBasePointer(bp.name, arr) - self._substitutions[bp] = type_erased_bp - else: - type_erased_bp = self._substitutions[bp] - - base_index = aggr.index * PsExpression.make( - PsConstant(struct_size, self._ctx.index_dtype) - ) - - member_name = lookup.member_name - member = struct_type.find_member(member_name) - assert member is not None - - np_struct = struct_type.numpy_dtype - assert np_struct is not None - assert np_struct.fields is not None - member_offset = np_struct.fields[member_name][1] - - byte_index = base_index + PsExpression.make( - PsConstant(member_offset, self._ctx.index_dtype) - ) - type_erased_access = PsArrayAccess(type_erased_bp, byte_index) - - deref = PsDeref( - PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)) - ) - - typify = Typifier(self._ctx) - deref = typify(deref) - return deref diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index 2368868a99ec36b817f0f6fe6aa8ef55608ac0e1..f0e4cc9f19f1a046125bb3e8aab5302a9df2790c 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -9,15 +9,17 @@ from ..ast.expressions import ( PsConstantExpr, PsLiteralExpr, PsCall, - PsDeref, + PsBufferAcc, PsSubscript, + PsLookup, PsUnOp, PsBinOp, PsArrayInitList, ) +from ..ast.util import determine_memory_object from ...types import PsDereferencableType -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..functions import PsMathFunction __all__ = ["HoistLoopInvariantDeclarations"] @@ -48,7 +50,11 @@ class HoistContext: case PsCall(func): return isinstance(func, PsMathFunction) and args_invariant(expr) - case PsSubscript(ptr, _) | PsDeref(ptr): + case PsSubscript() | PsLookup(): + return determine_memory_object(expr)[1] and args_invariant(expr) + + case PsBufferAcc(ptr, _): + # Regular pointer derefs are never invariant, since we cannot reason about aliasing ptr_type = cast(PsDereferencableType, ptr.get_dtype()) return ptr_type.base_type.const and args_invariant(expr) diff --git a/src/pystencils/backend/transformations/lower_to_c.py b/src/pystencils/backend/transformations/lower_to_c.py new file mode 100644 index 0000000000000000000000000000000000000000..ea832355bb1a53f94fc07cad670f86f98e5f6a2e --- /dev/null +++ b/src/pystencils/backend/transformations/lower_to_c.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from typing import cast +from functools import reduce +import operator + +from ..kernelcreation import KernelCreationContext, Typifier + +from ..constants import PsConstant +from ..memory import PsSymbol, PsBuffer, BufferBasePtr +from ..ast.structural import PsAstNode +from ..ast.expressions import ( + PsBufferAcc, + PsLookup, + PsExpression, + PsMemAcc, + PsAddressOf, + PsCast, + PsSymbolExpr, +) +from ...types import PsStructType, PsPointerType, PsUnsignedIntegerType + + +class LowerToC: + """Lower high-level IR constructs to C language concepts. + + This pass will replace a number of IR constructs that have no direct counterpart in the C language + to lower-level AST nodes. These include: + + - *Linearization of Buffer Accesses:* `PsBufferAcc` buffer accesses are linearized according to + their buffers' stride information and replaced by `PsMemAcc`. + - *Erasure of Anonymous Structs:* + For buffers whose element type is an anonymous struct, the struct type is erased from the base pointer, + making it a pointer to uint8_t. + Member lookups on accesses into these buffers are then transformed using type casts. + """ + + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self._typify = Typifier(ctx) + + self._substitutions: dict[PsSymbol, PsSymbol] = dict() + + def __call__(self, node: PsAstNode) -> PsAstNode: + self._substitutions = dict() + + node = self.visit(node) + + for old, new in self._substitutions.items(): + self._ctx.replace_symbol(old, new) + + return node + + def visit(self, node: PsAstNode) -> PsAstNode: + match node: + case PsBufferAcc(bptr, indices): + # Linearize + buf = node.buffer + + # Typifier allows different data types in each index + def maybe_cast(i: PsExpression): + if i.get_dtype() != buf.index_type: + return PsCast(buf.index_type, i) + else: + return i + + summands: list[PsExpression] = [ + maybe_cast(cast(PsExpression, self.visit(idx))) * PsExpression.make(stride) + for idx, stride in zip(indices, buf.strides, strict=True) + ] + + linearized_idx: PsExpression = ( + summands[0] + if len(summands) == 1 + else reduce(operator.add, summands) + ) + + mem_acc = PsMemAcc(bptr, linearized_idx) + + return self._typify.typify_expression( + mem_acc, target_type=buf.element_type + )[0] + + case PsLookup(aggr, member_name) if isinstance( + aggr, PsBufferAcc + ) and isinstance( + aggr.buffer.element_type, PsStructType + ) and aggr.buffer.element_type.anonymous: + # Need to lower this buffer-lookup + linearized_acc = self.visit(aggr) + return self._lower_anon_lookup( + cast(PsMemAcc, linearized_acc), aggr.buffer, member_name + ) + + case _: + node.children = [self.visit(c) for c in node.children] + + return node + + def _lower_anon_lookup( + self, aggr: PsMemAcc, buf: PsBuffer, member_name: str + ) -> PsExpression: + struct_type = cast(PsStructType, buf.element_type) + struct_size = struct_type.itemsize + + assert isinstance(aggr.pointer, PsSymbolExpr) + bp = aggr.pointer.symbol + bp_type = bp.get_dtype() + assert isinstance(bp_type, PsPointerType) + + # Need to keep track of base pointers already seen, since symbols must be unique + if bp not in self._substitutions: + erased_type = PsPointerType( + PsUnsignedIntegerType(8, const=bp_type.base_type.const), + const=bp_type.const, + restrict=bp_type.restrict, + ) + type_erased_bp = PsSymbol( + bp.name, + erased_type + ) + type_erased_bp.add_property(BufferBasePtr(buf)) + self._substitutions[bp] = type_erased_bp + else: + type_erased_bp = self._substitutions[bp] + + base_index = aggr.offset * PsExpression.make( + PsConstant(struct_size, self._ctx.index_dtype) + ) + + member = struct_type.find_member(member_name) + assert member is not None + + np_struct = struct_type.numpy_dtype + assert np_struct is not None + assert np_struct.fields is not None + member_offset = np_struct.fields[member_name][1] + + byte_index = base_index + PsExpression.make( + PsConstant(member_offset, self._ctx.index_dtype) + ) + type_erased_access = PsMemAcc(PsExpression.make(type_erased_bp), byte_index) + + deref = PsMemAcc( + PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)), + PsExpression.make(PsConstant(0)), + ) + + deref = self._typify(deref) + return deref diff --git a/src/pystencils/backend/transformations/select_intrinsics.py b/src/pystencils/backend/transformations/select_intrinsics.py index 7972de0699f52b1230a5e9c9e00b43d5d122f61f..3fb484c154fbb4ab873deea3e9b1d83c2f4354e6 100644 --- a/src/pystencils/backend/transformations/select_intrinsics.py +++ b/src/pystencils/backend/transformations/select_intrinsics.py @@ -6,7 +6,7 @@ from ..ast.structural import PsAstNode, PsAssignment, PsStatement from ..ast.expressions import PsExpression from ...types import PsVectorType, deconstify from ..ast.expressions import ( - PsVectorArrayAccess, + PsVectorMemAcc, PsSymbolExpr, PsConstantExpr, PsBinOp, @@ -66,7 +66,7 @@ class MaterializeVectorIntrinsics: def visit(self, node: PsAstNode) -> PsAstNode: match node: - case PsAssignment(lhs, rhs) if isinstance(lhs, PsVectorArrayAccess): + case PsAssignment(lhs, rhs) if isinstance(lhs, PsVectorMemAcc): vc = VecTypeCtx() vc.set(lhs.get_vector_type()) store_arg = self.visit_expr(rhs, vc) @@ -94,7 +94,7 @@ class MaterializeVectorIntrinsics: else: return expr - case PsVectorArrayAccess(): + case PsVectorMemAcc(): vc.set(expr.get_vector_type()) return self._platform.vector_load(expr) diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py index c7657ec51e5d0386087455e8ff927ef56dff0384..d46cc7a8ea96b646cf0db274b48036d94663c79e 100644 --- a/src/pystencils/boundaries/boundaryhandling.py +++ b/src/pystencils/boundaries/boundaryhandling.py @@ -12,7 +12,7 @@ from pystencils.types import PsIntegerType from pystencils.types.quick import Arr, SInt from pystencils.gpu.gpu_array_handler import GPUArrayHandler from pystencils.field import Field, FieldType -from pystencils.backend.kernelfunction import FieldPointerParam +from pystencils.backend.properties import FieldBasePtr try: # noinspection PyPep8Naming @@ -244,9 +244,9 @@ class BoundaryHandling: for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items(): kwargs[self._field_name] = b[self._field_name] kwargs['indexField'] = idx_arr - data_used_in_kernel = (p.field.name + data_used_in_kernel = (p.fields[0].name for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters - if isinstance(p, FieldPointerParam) and p.field.name not in kwargs) + if bool(p.get_properties(FieldBasePtr)) and p.fields[0].name not in kwargs) kwargs.update({name: b[name] for name in data_used_in_kernel}) self._boundary_object_to_boundary_info[b_obj].kernel(**kwargs) @@ -260,9 +260,9 @@ class BoundaryHandling: arguments = kwargs.copy() arguments[self._field_name] = b[self._field_name] arguments['indexField'] = idx_arr - data_used_in_kernel = (p.field.name + data_used_in_kernel = (p.fields[0].name for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters - if isinstance(p, FieldPointerParam) and p.field.name not in arguments) + if bool(p.get_properties(FieldBasePtr)) and p.fields[0].name not in arguments) arguments.update({name: b[name] for name in data_used_in_kernel if name not in arguments}) kernel = self._boundary_object_to_boundary_info[b_obj].kernel diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index ae64bdea3a95db84641e91ff455825b21f772e0b..7d9ac7aa4465c264855d79ae7d56260e0dd698eb 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -20,8 +20,9 @@ from .backend.kernelcreation.iteration_space import ( from .backend.transformations import ( EliminateConstants, - EraseAnonymousStructTypes, + LowerToC, SelectFunctions, + CanonicalizeSymbols, ) from .backend.kernelfunction import ( create_cpu_kernel_function, @@ -131,7 +132,7 @@ def create_kernel( f"Code generation for target {target} not implemented" ) - # Simplifying transformations + # Fold and extract constants elim_constants = EliminateConstants(ctx, extract_constant_exprs=True) kernel_ast = cast(PsBlock, elim_constants(kernel_ast)) @@ -143,12 +144,23 @@ def create_kernel( kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) - erase_anons = EraseAnonymousStructTypes(ctx) - kernel_ast = cast(PsBlock, erase_anons(kernel_ast)) + # Lowering + lower_to_c = LowerToC(ctx) + kernel_ast = cast(PsBlock, lower_to_c(kernel_ast)) select_functions = SelectFunctions(platform) kernel_ast = cast(PsBlock, select_functions(kernel_ast)) + # Late canonicalization and constant elimination passes + # * Since lowering introduces new index calculations and indexing symbols into the AST, + # * these need to be handled here + + canonicalize = CanonicalizeSymbols(ctx, True) + kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) + + late_fold_constants = EliminateConstants(ctx, extract_constant_exprs=False) + kernel_ast = cast(PsBlock, late_fold_constants(kernel_ast)) + if config.target.is_cpu(): return create_cpu_kernel_function( ctx, diff --git a/src/pystencils/symb.py b/src/pystencils/symb.py index 0c682b26113c70ca2304bc63a15a6aa7e8d8ad9f..8e293405817c2189ebe7428bc1a53bbde8ca8073 100644 --- a/src/pystencils/symb.py +++ b/src/pystencils/symb.py @@ -9,6 +9,9 @@ from .sympyextensions.integer_functions import ( int_div, int_rem, int_power_of_2, + round_to_multiple_towards_zero, + ceil_to_multiple, + div_ceil, ) __all__ = [ @@ -20,4 +23,7 @@ __all__ = [ "int_div", "int_rem", "int_power_of_2", + "round_to_multiple_towards_zero", + "ceil_to_multiple", + "div_ceil", ] diff --git a/src/pystencils/sympyextensions/integer_functions.py b/src/pystencils/sympyextensions/integer_functions.py index eb3bb4ccc79d06d54e320bb0b442ea7dad1c670a..cf25472c89cd4deda18de889e92139e7c2a28067 100644 --- a/src/pystencils/sympyextensions/integer_functions.py +++ b/src/pystencils/sympyextensions/integer_functions.py @@ -1,4 +1,5 @@ import sympy as sp +import warnings from pystencils.sympyextensions import is_integer_sequence @@ -46,17 +47,19 @@ class bitwise_or(IntegerFunctionTwoArgsMixIn): # noinspection PyPep8Naming class int_div(IntegerFunctionTwoArgsMixIn): """C-style round-to-zero integer division""" - + def _eval_op(self, arg1, arg2): from ..utils import c_intdiv + return c_intdiv(arg1, arg2) class int_rem(IntegerFunctionTwoArgsMixIn): """C-style round-to-zero integer remainder""" - + def _eval_op(self, arg1, arg2): from ..utils import c_rem + return c_rem(arg1, arg2) @@ -68,66 +71,65 @@ class int_power_of_2(IntegerFunctionTwoArgsMixIn): # noinspection PyPep8Naming -class modulo_floor(sp.Function): - """Returns the next smaller integer divisible by given divisor. +class round_to_multiple_towards_zero(IntegerFunctionTwoArgsMixIn): + """Returns the next smaller/equal in magnitude integer divisible by given + divisor. Examples: - >>> modulo_floor(9, 4) + >>> round_to_multiple_towards_zero(9, 4) 8 - >>> modulo_floor(11, 4) + >>> round_to_multiple_towards_zero(11, -4) 8 - >>> modulo_floor(12, 4) + >>> round_to_multiple_towards_zero(12, 4) 12 + >>> round_to_multiple_towards_zero(-9, 4) + -8 + >>> round_to_multiple_towards_zero(-9, -4) + -8 """ - nargs = 2 - is_integer = True - def __new__(cls, integer, divisor): - if is_integer_sequence((integer, divisor)): - return (int(integer) // int(divisor)) * divisor - else: - return super().__new__(cls, integer, divisor) + @classmethod + def eval(cls, arg1, arg2): + from ..utils import c_intdiv - # TODO: Implement this in FreezeExpressions - # def to_c(self, print_func): - # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - # assert dtype.is_int() - # return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]), - # print_func(self.args[1]), dtype=dtype) + if is_integer_sequence((arg1, arg2)): + return c_intdiv(arg1, arg2) * arg2 + + def _eval_op(self, arg1, arg2): + return self.eval(arg1, arg2) # noinspection PyPep8Naming -class modulo_ceil(sp.Function): - """Returns the next bigger integer divisible by given divisor. +class ceil_to_multiple(IntegerFunctionTwoArgsMixIn): + """For positive input, returns the next greater/equal integer divisible + by given divisor. The return value is unspecified if either argument is + negative. Examples: - >>> modulo_ceil(9, 4) + >>> ceil_to_multiple(9, 4) 12 - >>> modulo_ceil(11, 4) + >>> ceil_to_multiple(11, 4) 12 - >>> modulo_ceil(12, 4) + >>> ceil_to_multiple(12, 4) 12 """ - nargs = 2 - is_integer = True - def __new__(cls, integer, divisor): - if is_integer_sequence((integer, divisor)): - return integer if integer % divisor == 0 else ((integer // divisor) + 1) * divisor - else: - return super().__new__(cls, integer, divisor) + @classmethod + def eval(cls, arg1, arg2): + from ..utils import c_intdiv - # TODO: Implement this in FreezeExpressions - # def to_c(self, print_func): - # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - # assert dtype.is_int() - # code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))" - # return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) + if is_integer_sequence((arg1, arg2)): + return c_intdiv(arg1 + arg2 - 1, arg2) * arg2 + + def _eval_op(self, arg1, arg2): + return self.eval(arg1, arg2) # noinspection PyPep8Naming -class div_ceil(sp.Function): - """Integer division that is always rounded up +class div_ceil(IntegerFunctionTwoArgsMixIn): + """For positive input, integer division that is always rounded up, i.e. + `div_ceil(a, b) = ceil(div(a, b))`. The return value is unspecified if + either argument is negative. Examples: >>> div_ceil(9, 4) @@ -135,45 +137,46 @@ class div_ceil(sp.Function): >>> div_ceil(8, 4) 2 """ - nargs = 2 - is_integer = True - def __new__(cls, integer, divisor): - if is_integer_sequence((integer, divisor)): - return integer // divisor if integer % divisor == 0 else (integer // divisor) + 1 - else: - return super().__new__(cls, integer, divisor) + @classmethod + def eval(cls, arg1, arg2): + from ..utils import c_intdiv - # TODO: Implement this in FreezeExpressions - # def to_c(self, print_func): - # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - # assert dtype.is_int() - # code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )" - # return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) + if is_integer_sequence((arg1, arg2)): + return c_intdiv(arg1 + arg2 - 1, arg2) + + def _eval_op(self, arg1, arg2): + return self.eval(arg1, arg2) + + +# Deprecated functions. # noinspection PyPep8Naming -class div_floor(sp.Function): - """Integer division +class modulo_floor: + def __new__(cls, integer, divisor): + warnings.warn( + "`modulo_floor` is deprecated. Use `round_to_multiple_towards_zero` instead.", + DeprecationWarning, + ) + return round_to_multiple_towards_zero(integer, divisor) - Examples: - >>> div_floor(9, 4) - 2 - >>> div_floor(8, 4) - 2 - """ - nargs = 2 - is_integer = True +# noinspection PyPep8Naming +class modulo_ceil(sp.Function): + def __new__(cls, integer, divisor): + warnings.warn( + "`modulo_ceil` is deprecated. Use `ceil_to_multiple` instead.", + DeprecationWarning, + ) + return ceil_to_multiple(integer, divisor) + + +# noinspection PyPep8Naming +class div_floor(sp.Function): def __new__(cls, integer, divisor): - if is_integer_sequence((integer, divisor)): - return integer // divisor - else: - return super().__new__(cls, integer, divisor) - - # TODO: Implement this in FreezeExpressions - # def to_c(self, print_func): - # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - # assert dtype.is_int() - # code = "(({dtype})({0}) / ({dtype})({1}))" - # return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) + warnings.warn( + "`div_floor` is deprecated. Use `int_div` instead.", + DeprecationWarning, + ) + return int_div(integer, divisor) diff --git a/src/pystencils/sympyextensions/pointers.py b/src/pystencils/sympyextensions/pointers.py index c69f9376dd31c8e0f7976c750e9af17308b7991e..2ebeba7c927a952d5ee498e7c007b1e03ed3760b 100644 --- a/src/pystencils/sympyextensions/pointers.py +++ b/src/pystencils/sympyextensions/pointers.py @@ -31,3 +31,14 @@ class AddressOf(sp.Function): return PsPointerType(arg_type, restrict=True, const=True) else: raise ValueError(f'pystencils supports only non void pointers. Current address_of type: {self.args[0]}') + + +class mem_acc(sp.Function): + """Memory access through a raw pointer with an offset. + + This function should be used to model offset memory accesses through raw pointers. + """ + + @classmethod + def eval(cls, ptr, offset): + return None diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py index 5771eaca84413708c68c4f7941e07cbd63403e9e..8e7d27f58265c08461cba6b05373848112a6fee7 100644 --- a/src/pystencils/types/parsing.py +++ b/src/pystencils/types/parsing.py @@ -8,6 +8,7 @@ from .types import ( PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, + PsBoolType, ) UserTypeSpec = str | type | np.dtype | PsType @@ -143,6 +144,9 @@ def parse_type_string(s: str) -> PsType: def parse_type_name(typename: str, const: bool): match typename: + case "bool": + return PsBoolType(const=const) + case "int" | "int64" | "int64_t": return PsSignedIntegerType(64, const=const) case "int32" | "int32_t": diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index 61e3d73fd5b66e06b48ea9de788b3fa1b51a7a61..d3d18720cf1ff3c4af14f6c276da52098adfbdd2 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import final, Any, Sequence +from typing import final, Any, Sequence, SupportsIndex from dataclasses import dataclass import numpy as np @@ -91,38 +91,65 @@ class PsPointerType(PsDereferencableType): def c_string(self) -> str: base_str = self._base_type.c_string() restrict_str = " RESTRICT" if self._restrict else "" - return f"{base_str} *{restrict_str} {self._const_string()}" + const_str = " const" if self.const else "" + return f"{base_str} *{restrict_str}{const_str}" def __repr__(self) -> str: return f"PsPointerType( {repr(self.base_type)}, const={self.const}, restrict={self.restrict} )" class PsArrayType(PsDereferencableType): - """C array type of known or unknown size.""" + """Multidimensional array of fixed shape. + + The element type of an array is never const; only the array itself can be. + If ``element_type`` is const, its constness will be removed. + """ def __init__( - self, base_type: PsType, length: int | None = None, const: bool = False + self, element_type: PsType, shape: SupportsIndex | Sequence[SupportsIndex], const: bool = False ): - self._length = length - super().__init__(base_type, const) + from operator import index + if isinstance(shape, SupportsIndex): + shape = (index(shape),) + else: + shape = tuple(index(s) for s in shape) + + if not shape or any(s <= 0 for s in shape): + raise ValueError(f"Invalid array shape: {shape}") + + if isinstance(element_type, PsArrayType): + raise ValueError("Element type of array cannot be another array.") + + element_type = deconstify(element_type) + + self._shape = shape + super().__init__(element_type, const) def __args__(self) -> tuple[Any, ...]: """ - >>> t = PsArrayType(PsBoolType(), 13) + >>> t = PsArrayType(PsBoolType(), (13, 42)) >>> t == PsArrayType(*t.__args__()) True """ - return (self._base_type, self._length) + return (self._base_type, self._shape) @property - def length(self) -> int | None: - return self._length + def shape(self) -> tuple[int, ...]: + """Shape of this array""" + return self._shape + + @property + def dim(self) -> int: + """Dimensionality of this array""" + return len(self._shape) def c_string(self) -> str: - return f"{self._base_type.c_string()} [{str(self._length) if self._length is not None else ''}]" + arr_brackets = "".join(f"[{s}]" for s in self._shape) + const = self._const_string() + return const + self._base_type.c_string() + arr_brackets def __repr__(self) -> str: - return f"PsArrayType(element_type={repr(self._base_type)}, size={self._length}, const={self._const})" + return f"PsArrayType(element_type={repr(self._base_type)}, shape={self._shape}, const={self._const})" class PsStructType(PsType): @@ -131,6 +158,8 @@ class PsStructType(PsType): A struct type is defined by its sequence of members. The struct may optionally have a name, although the code generator currently does not support named structs and treats them the same way as anonymous structs. + + Struct member types cannot be ``const``; if a ``const`` member type is passed, its constness will be removed. """ @dataclass(frozen=True) @@ -138,6 +167,10 @@ class PsStructType(PsType): name: str dtype: PsType + def __post_init__(self): + # Need to use object.__setattr__ because instances are frozen + object.__setattr__(self, "dtype", deconstify(self.dtype)) + @staticmethod def _canonical_members(members: Sequence[PsStructType.Member | tuple[str, PsType]]): return tuple( diff --git a/tests/frontend/test_sympyextensions.py b/tests/frontend/test_sympyextensions.py index 05c11996864073be98c6aaea51de10db3867dcfb..ad5d2513b4400db938b8372a83ef43cc9339b35d 100644 --- a/tests/frontend/test_sympyextensions.py +++ b/tests/frontend/test_sympyextensions.py @@ -3,6 +3,7 @@ import numpy as np import sympy as sp import pystencils +from pystencils import Assignment from pystencils.sympyextensions import replace_second_order_products from pystencils.sympyextensions import remove_higher_order_terms from pystencils.sympyextensions import complete_the_squares_in_exp @@ -13,10 +14,18 @@ from pystencils.sympyextensions import common_denominator from pystencils.sympyextensions import get_symmetric_part from pystencils.sympyextensions import scalar_product from pystencils.sympyextensions import kronecker_delta - -from pystencils import Assignment -from pystencils.sympyextensions.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt, - insert_fast_divisions, insert_fast_sqrts) +from pystencils.sympyextensions.fast_approximation import ( + fast_division, + fast_inv_sqrt, + fast_sqrt, + insert_fast_divisions, + insert_fast_sqrts, +) +from pystencils.sympyextensions.integer_functions import ( + round_to_multiple_towards_zero, + ceil_to_multiple, + div_ceil, +) def test_utility(): @@ -39,10 +48,10 @@ def test_utility(): def test_replace_second_order_products(): - x, y = sympy.symbols('x y') + x, y = sympy.symbols("x y") expr = 4 * x * y - expected_expr_positive = 2 * ((x + y) ** 2 - x ** 2 - y ** 2) - expected_expr_negative = 2 * (-(x - y) ** 2 + x ** 2 + y ** 2) + expected_expr_positive = 2 * ((x + y) ** 2 - x**2 - y**2) + expected_expr_negative = 2 * (-((x - y) ** 2) + x**2 + y**2) result = replace_second_order_products(expr, search_symbols=[x, y], positive=True) assert result == expected_expr_positive @@ -55,15 +64,17 @@ def test_replace_second_order_products(): result = replace_second_order_products(expr, search_symbols=[x, y], positive=None) assert result == expected_expr_positive - a = [Assignment(sympy.symbols('z'), x + y)] - replace_second_order_products(expr, search_symbols=[x, y], positive=True, replace_mixed=a) + a = [Assignment(sympy.symbols("z"), x + y)] + replace_second_order_products( + expr, search_symbols=[x, y], positive=True, replace_mixed=a + ) assert len(a) == 2 assert replace_second_order_products(4 + y, search_symbols=[x, y]) == y + 4 def test_remove_higher_order_terms(): - x, y = sympy.symbols('x y') + x, y = sympy.symbols("x y") expr = sympy.Mul(x, y) @@ -81,19 +92,19 @@ def test_remove_higher_order_terms(): def test_complete_the_squares_in_exp(): - a, b, c, s, n = sympy.symbols('a b c s n') - expr = a * s ** 2 + b * s + c + a, b, c, s, n = sympy.symbols("a b c s n") + expr = a * s**2 + b * s + c result = complete_the_squares_in_exp(expr, symbols_to_complete=[s]) assert result == expr - expr = sympy.exp(a * s ** 2 + b * s + c) - expected_result = sympy.exp(a*s**2 + c - b**2 / (4*a)) + expr = sympy.exp(a * s**2 + b * s + c) + expected_result = sympy.exp(a * s**2 + c - b**2 / (4 * a)) result = complete_the_squares_in_exp(expr, symbols_to_complete=[s]) assert result == expected_result def test_extract_most_common_factor(): - x, y = sympy.symbols('x y') + x, y = sympy.symbols("x y") expr = 1 / (x + y) + 3 / (x + y) + 3 / (x + y) most_common_factor = extract_most_common_factor(expr) @@ -115,98 +126,98 @@ def test_extract_most_common_factor(): def test_count_operations(): - x, y, z = sympy.symbols('x y z') - expr = 1/x + y * sympy.sqrt(z) + x, y, z = sympy.symbols("x y z") + expr = 1 / x + y * sympy.sqrt(z) ops = count_operations(expr, only_type=None) - assert ops['adds'] == 1 - assert ops['muls'] == 1 - assert ops['divs'] == 1 - assert ops['sqrts'] == 1 + assert ops["adds"] == 1 + assert ops["muls"] == 1 + assert ops["divs"] == 1 + assert ops["sqrts"] == 1 expr = 1 / sympy.sqrt(z) ops = count_operations(expr, only_type=None) - assert ops['adds'] == 0 - assert ops['muls'] == 0 - assert ops['divs'] == 1 - assert ops['sqrts'] == 1 + assert ops["adds"] == 0 + assert ops["muls"] == 0 + assert ops["divs"] == 1 + assert ops["sqrts"] == 1 expr = sympy.Rel(1 / sympy.sqrt(z), 5) ops = count_operations(expr, only_type=None) - assert ops['adds'] == 0 - assert ops['muls'] == 0 - assert ops['divs'] == 1 - assert ops['sqrts'] == 1 + assert ops["adds"] == 0 + assert ops["muls"] == 0 + assert ops["divs"] == 1 + assert ops["sqrts"] == 1 expr = sympy.sqrt(x + y) expr = insert_fast_sqrts(expr).atoms(fast_sqrt) ops = count_operations(*expr, only_type=None) - assert ops['fast_sqrts'] == 1 + assert ops["fast_sqrts"] == 1 expr = sympy.sqrt(x / y) expr = insert_fast_divisions(expr).atoms(fast_division) ops = count_operations(*expr, only_type=None) - assert ops['fast_div'] == 1 + assert ops["fast_div"] == 1 - expr = pystencils.Assignment(sympy.Symbol('tmp'), 3 / sympy.sqrt(x + y)) + expr = pystencils.Assignment(sympy.Symbol("tmp"), 3 / sympy.sqrt(x + y)) expr = insert_fast_sqrts(expr).atoms(fast_inv_sqrt) ops = count_operations(*expr, only_type=None) - assert ops['fast_inv_sqrts'] == 1 + assert ops["fast_inv_sqrts"] == 1 expr = sympy.Piecewise((1.0, x > 0), (0.0, True)) + y * z ops = count_operations(expr, only_type=None) - assert ops['adds'] == 1 + assert ops["adds"] == 1 - expr = sympy.Pow(1/x + y * sympy.sqrt(z), 100) + expr = sympy.Pow(1 / x + y * sympy.sqrt(z), 100) ops = count_operations(expr, only_type=None) - assert ops['adds'] == 1 - assert ops['muls'] == 99 - assert ops['divs'] == 1 - assert ops['sqrts'] == 1 + assert ops["adds"] == 1 + assert ops["muls"] == 99 + assert ops["divs"] == 1 + assert ops["sqrts"] == 1 expr = x / y ops = count_operations(expr, only_type=None) - assert ops['divs'] == 1 + assert ops["divs"] == 1 expr = x + z / y + z ops = count_operations(expr, only_type=None) - assert ops['adds'] == 2 - assert ops['divs'] == 1 + assert ops["adds"] == 2 + assert ops["divs"] == 1 - expr = sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) + expr = sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False)) ops = count_operations(expr, only_type=None) - assert ops['muls'] == 99 + assert ops["muls"] == 99 - expr = 1 / sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) + expr = 1 / sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False)) ops = count_operations(expr, only_type=None) - assert ops['divs'] == 1 - assert ops['muls'] == 99 + assert ops["divs"] == 1 + assert ops["muls"] == 99 - expr = (y + z) / sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) + expr = (y + z) / sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False)) ops = count_operations(expr, only_type=None) - assert ops['adds'] == 1 - assert ops['divs'] == 1 - assert ops['muls'] == 99 + assert ops["adds"] == 1 + assert ops["divs"] == 1 + assert ops["muls"] == 99 def test_common_denominator(): - x = sympy.symbols('x') + x = sympy.symbols("x") expr = sympy.Rational(1, 2) + x * sympy.Rational(2, 3) cm = common_denominator(expr) assert cm == 6 def test_get_symmetric_part(): - x, y, z = sympy.symbols('x y z') - expr = x / 9 - y ** 2 / 6 + z ** 2 / 3 + z / 3 - expected_result = x / 9 - y ** 2 / 6 + z ** 2 / 3 - sym_part = get_symmetric_part(expr, sympy.symbols(f'y z')) + x, y, z = sympy.symbols("x y z") + expr = x / 9 - y**2 / 6 + z**2 / 3 + z / 3 + expected_result = x / 9 - y**2 / 6 + z**2 / 3 + sym_part = get_symmetric_part(expr, sympy.symbols(f"y z")) assert sym_part == expected_result def test_simplify_by_equality(): - x, y, z = sp.symbols('x, y, z') - p, q = sp.symbols('p, q') + x, y, z = sp.symbols("x, y, z") + p, q = sp.symbols("p, q") # Let x = y + z expr = x * p - y * p + z * q @@ -219,9 +230,24 @@ def test_simplify_by_equality(): expr = x * (y + z) - y * z expr = simplify_by_equality(expr, x, y, z) - assert expr == x*y + z**2 + assert expr == x * y + z**2 # Let x = y + 2 expr = x * p - 2 * p expr = simplify_by_equality(expr, x, y, 2) assert expr == y * p + + +def test_integer_functions(): + assert round_to_multiple_towards_zero(9, 4) == 8 + assert round_to_multiple_towards_zero(11, -4) == 8 + assert round_to_multiple_towards_zero(12, 4) == 12 + assert round_to_multiple_towards_zero(-9, 4) == -8 + assert round_to_multiple_towards_zero(-9, -4) == -8 + + assert ceil_to_multiple(9, 4) == 12 + assert ceil_to_multiple(11, 4) == 12 + assert ceil_to_multiple(12, 4) == 12 + + assert div_ceil(9, 4) == 3 + assert div_ceil(8, 4) == 2 diff --git a/tests/nbackend/kernelcreation/test_context.py b/tests/nbackend/kernelcreation/test_context.py index 9701013b000c26446649c82e6c96c5a64861f76c..384fc93158a9f7aa7ff9911b20382c0b79ed36ee 100644 --- a/tests/nbackend/kernelcreation/test_context.py +++ b/tests/nbackend/kernelcreation/test_context.py @@ -5,6 +5,8 @@ from pystencils import Field, TypedSymbol, FieldType, DynamicType from pystencils.backend.kernelcreation import KernelCreationContext from pystencils.backend.constants import PsConstant +from pystencils.backend.memory import PsSymbol +from pystencils.backend.properties import FieldShape, FieldStride from pystencils.backend.exceptions import KernelConstraintsError from pystencils.types.quick import SInt, Fp from pystencils.types import deconstify @@ -14,7 +16,7 @@ def test_field_arrays(): ctx = KernelCreationContext(index_dtype=SInt(16)) f = Field.create_generic("f", 3, Fp(32)) - f_arr = ctx.get_array(f) + f_arr = ctx.get_buffer(f) assert f_arr.element_type == f.dtype == Fp(32) assert len(f_arr.shape) == len(f.shape) + 1 == 4 @@ -23,9 +25,17 @@ def test_field_arrays(): assert f_arr.index_type == ctx.index_dtype == SInt(16) assert f_arr.shape[0].dtype == ctx.index_dtype == SInt(16) + for i, s in enumerate(f_arr.shape[:1]): + assert isinstance(s, PsSymbol) + assert FieldShape(f, i) in s.properties + + for i, s in enumerate(f_arr.strides[:1]): + assert isinstance(s, PsSymbol) + assert FieldStride(f, i) in s.properties + g = Field.create_generic("g", 3, index_shape=(2, 4), dtype=Fp(16)) - g_arr = ctx.get_array(g) - + g_arr = ctx.get_buffer(g) + assert g_arr.element_type == g.dtype == Fp(16) assert len(g_arr.shape) == len(g.spatial_shape) + len(g.index_shape) == 5 assert isinstance(g_arr.shape[3], PsConstant) and g_arr.shape[3].value == 2 @@ -39,26 +49,23 @@ def test_field_arrays(): FieldType.GENERIC, Fp(32), (0, 1), - ( - TypedSymbol("nx", SInt(32)), - TypedSymbol("ny", SInt(32)), - 1 - ), - ( - TypedSymbol("sx", SInt(32)), - TypedSymbol("sy", SInt(32)), - 1 - ) - ) - - h_arr = ctx.get_array(h) + (TypedSymbol("nx", SInt(32)), TypedSymbol("ny", SInt(32)), 1), + (TypedSymbol("sx", SInt(32)), TypedSymbol("sy", SInt(32)), 1), + ) + + h_arr = ctx.get_buffer(h) assert h_arr.index_type == SInt(32) - + for s in chain(h_arr.shape, h_arr.strides): assert deconstify(s.get_dtype()) == SInt(32) - assert [s.name for s in chain(h_arr.shape[:2], h_arr.strides[:2])] == ["nx", "ny", "sx", "sy"] + assert [s.name for s in chain(h_arr.shape[:2], h_arr.strides[:2])] == [ + "nx", + "ny", + "sx", + "sy", + ] def test_invalid_fields(): @@ -70,11 +77,11 @@ def test_invalid_fields(): Fp(32), (0,), (TypedSymbol("nx", SInt(32)),), - (TypedSymbol("sx", SInt(64)),) + (TypedSymbol("sx", SInt(64)),), ) - + with pytest.raises(KernelConstraintsError): - _ = ctx.get_array(h) + _ = ctx.get_buffer(h) h = Field( "h", @@ -82,11 +89,11 @@ def test_invalid_fields(): Fp(32), (0,), (TypedSymbol("nx", Fp(32)),), - (TypedSymbol("sx", Fp(32)),) + (TypedSymbol("sx", Fp(32)),), ) - + with pytest.raises(KernelConstraintsError): - _ = ctx.get_array(h) + _ = ctx.get_buffer(h) h = Field( "h", @@ -94,8 +101,39 @@ def test_invalid_fields(): Fp(32), (0,), (TypedSymbol("nx", DynamicType.NUMERIC_TYPE),), - (TypedSymbol("sx", DynamicType.NUMERIC_TYPE),) + (TypedSymbol("sx", DynamicType.NUMERIC_TYPE),), ) - + with pytest.raises(KernelConstraintsError): - _ = ctx.get_array(h) + _ = ctx.get_buffer(h) + + +def test_duplicate_fields(): + f = Field.create_generic("f", 3) + g = f.new_field_with_different_name("g") + + # f and g have the same indexing symbols + assert f.shape == g.shape + assert f.strides == g.strides + + ctx = KernelCreationContext() + + f_buf = ctx.get_buffer(f) + g_buf = ctx.get_buffer(g) + + for sf, sg in zip(chain(f_buf.shape, f_buf.strides), chain(g_buf.shape, g_buf.strides)): + # Must be the same + assert sf == sg + + for i, s in enumerate(f_buf.shape[:-1]): + assert isinstance(s, PsSymbol) + assert FieldShape(f, i) in s.properties + assert FieldShape(g, i) in s.properties + + for i, s in enumerate(f_buf.strides[:-1]): + assert isinstance(s, PsSymbol) + assert FieldStride(f, i) in s.properties + assert FieldStride(g, i) in s.properties + + # Base pointers must be different, though! + assert f_buf.base_pointer != g_buf.base_pointer diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 270c8f44a6b2a64500446e2284866436069cc704..ce4f6178511aaad913819eafb59d9ccae42ee992 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -1,15 +1,23 @@ import sympy as sp import pytest -from pystencils import Assignment, fields, create_type, create_numeric_type, TypedSymbol, DynamicType +from pystencils import ( + Assignment, + fields, + create_type, + create_numeric_type, + TypedSymbol, + DynamicType, +) from pystencils.sympyextensions import CastFunc +from pystencils.sympyextensions.pointers import mem_acc from pystencils.backend.ast.structural import ( PsAssignment, PsDeclaration, ) from pystencils.backend.ast.expressions import ( - PsArrayAccess, + PsBufferAcc, PsBitwiseAnd, PsBitwiseOr, PsBitwiseXor, @@ -29,6 +37,13 @@ from pystencils.backend.ast.expressions import ( PsGe, PsCall, PsCast, + PsConstantExpr, + PsAdd, + PsMul, + PsSub, + PsArrayInitList, + PsSubscript, + PsMemAcc, ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import PsMathFunction, MathFunctions @@ -47,6 +62,9 @@ from pystencils.sympyextensions.integer_functions import ( bitwise_xor, int_div, int_power_of_2, + round_to_multiple_towards_zero, + ceil_to_multiple, + div_ceil, ) @@ -88,22 +106,20 @@ def test_freeze_fields(): f, g = fields("f, g : [1D]") asm = Assignment(f.center(0), g.center(0)) - f_arr = ctx.get_array(f) - g_arr = ctx.get_array(g) + f_arr = ctx.get_buffer(f) + g_arr = ctx.get_buffer(g) fasm = freeze(asm) zero = PsExpression.make(PsConstant(0)) - lhs = PsArrayAccess( + lhs = PsBufferAcc( f_arr.base_pointer, - (PsExpression.make(counter) + zero) * PsExpression.make(f_arr.strides[0]) - + zero * one, + (PsExpression.make(counter) + zero, zero) ) - rhs = PsArrayAccess( + rhs = PsBufferAcc( g_arr.base_pointer, - (PsExpression.make(counter) + zero) * PsExpression.make(g_arr.strides[0]) - + zero * one, + (PsExpression.make(counter) + zero, zero) ) should = PsAssignment(lhs, rhs) @@ -142,10 +158,13 @@ def test_freeze_integer_functions(): z2 = PsExpression.make(ctx.get_symbol("z", ctx.index_dtype)) x, y, z = sp.symbols("x, y, z") + one = PsExpression.make(PsConstant(1)) asms = [ Assignment(z, int_div(x, y)), Assignment(z, int_power_of_2(x, y)), - # Assignment(z, modulo_floor(x, y)), + Assignment(z, round_to_multiple_towards_zero(x, y)), + Assignment(z, ceil_to_multiple(x, y)), + Assignment(z, div_ceil(x, y)), ] fasms = [freeze(asm) for asm in asms] @@ -153,7 +172,9 @@ def test_freeze_integer_functions(): should = [ PsDeclaration(z2, PsIntDiv(x2, y2)), PsDeclaration(z2, PsLeftShift(PsExpression.make(PsConstant(1)), x2)), - # PsDeclaration(z2, PsMul(PsIntDiv(x2, y2), y2)), + PsDeclaration(z2, PsIntDiv(x2, y2) * y2), + PsDeclaration(z2, PsIntDiv(x2 + y2 - one, y2) * y2), + PsDeclaration(z2, PsIntDiv(x2 + y2 - one, y2)), ] for fasm, correct in zip(fasms, should): @@ -276,11 +297,11 @@ def test_dynamic_types(): p, q = [TypedSymbol(n, DynamicType.INDEX_TYPE) for n in "pq"] expr = freeze(x + y) - + assert ctx.get_symbol("x").dtype == ctx.default_dtype assert ctx.get_symbol("y").dtype == ctx.default_dtype - expr = freeze(p - q) + expr = freeze(p - q) assert ctx.get_symbol("p").dtype == ctx.index_dtype assert ctx.get_symbol("q").dtype == ctx.index_dtype @@ -305,3 +326,129 @@ def test_cast_func(): expr = freeze(CastFunc.as_index(z)) assert expr.structurally_equal(PsCast(ctx.index_dtype, z2)) + + expr = freeze(CastFunc(42, create_type("int16"))) + assert expr.structurally_equal(PsConstantExpr(PsConstant(42, create_type("int16")))) + + +def test_add_sub(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x = sp.Symbol("x") + y = sp.Symbol("y", negative=True) + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + + two = PsExpression.make(PsConstant(2)) + minus_two = PsExpression.make(PsConstant(-2)) + + expr = freeze(x + y) + assert expr.structurally_equal(PsAdd(x2, y2)) + + expr = freeze(x - y) + assert expr.structurally_equal(PsSub(x2, y2)) + + expr = freeze(x + 2 * y) + assert expr.structurally_equal(PsAdd(x2, PsMul(two, y2))) + + expr = freeze(x - 2 * y) + assert expr.structurally_equal(PsAdd(x2, PsMul(minus_two, y2))) + + +def test_tuple_array_literals(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x, y, z = sp.symbols("x, y, z") + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + + one = PsExpression.make(PsConstant(1)) + three = PsExpression.make(PsConstant(3)) + four = PsExpression.make(PsConstant(4)) + + arr_literal = freeze(sp.Tuple(3 + y, z, z / 4)) + assert arr_literal.structurally_equal( + PsArrayInitList([three + y2, z2, one / four * z2]) + ) + + +def test_nested_tuples(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + def f(n): + return freeze(sp.sympify(n)) + + shape = (2, 3, 2) + symb_arr = sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10))) + arr_literal = freeze(symb_arr) + + assert isinstance(arr_literal, PsArrayInitList) + assert arr_literal.shape == shape + + assert arr_literal.structurally_equal( + PsArrayInitList( + [ + ((f(1), f(2)), (f(3), f(4)), (f(5), f(6))), + ((f(5), f(6)), (f(7), f(8)), (f(9), f(10))), + ] + ) + ) + + +def test_invalid_arrays(): + ctx = KernelCreationContext() + + freeze = FreezeExpressions(ctx) + # invalid: nonuniform nesting depth + symb_arr = sp.Tuple((3, 32), 14) + with pytest.raises(FreezeError): + _ = freeze(symb_arr) + + # invalid: nonuniform sub-array length + symb_arr = sp.Tuple((3, 32), (14, -7, 3)) + with pytest.raises(FreezeError): + _ = freeze(symb_arr) + + # invalid: empty subarray + symb_arr = sp.Tuple((), (0, -9)) + with pytest.raises(FreezeError): + _ = freeze(symb_arr) + + # invalid: all subarrays empty + symb_arr = sp.Tuple((), ()) + with pytest.raises(FreezeError): + _ = freeze(symb_arr) + + +def test_memory_access(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + ptr = sp.Symbol("ptr") + expr = freeze(mem_acc(ptr, 31)) + + assert isinstance(expr, PsMemAcc) + assert expr.pointer.structurally_equal(PsExpression.make(ctx.get_symbol("ptr"))) + assert expr.offset.structurally_equal(PsExpression.make(PsConstant(31))) + + +def test_indexed(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x, y, z = sp.symbols("x, y, z") + a = sp.IndexedBase("a") + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + a2 = PsExpression.make(ctx.get_symbol("a")) + + expr = freeze(a[x, y, z]) + assert expr.structurally_equal(PsSubscript(a2, (x2, y2, z2))) diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index 8ff678fbad398af18b71f2f30145ff34ab269685..5d56abd2b818fa74fbd48aac0216d472112f8c64 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -23,7 +23,7 @@ def test_slices_over_field(): islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, 1)) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) dims = ispace.dimensions @@ -58,7 +58,7 @@ def test_slices_with_fixed_size_field(): islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, 1)) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) dims = ispace.dimensions @@ -87,7 +87,7 @@ def test_singular_slice_over_field(): archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx") ctx.add_field(archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) islice = (4, -3) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) @@ -113,7 +113,7 @@ def test_slices_with_negative_start(): archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx") ctx.add_field(archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) islice = (slice(-3, -1, 1), slice(-4, None, 1)) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 2ebfa2ec8ec8cf542aa55bbccd770ef0a7e9f5d4..988fa4bb8b10c2c243abfd3a171657ad6bf5e418 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -5,6 +5,7 @@ import numpy as np from typing import cast from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment +from pystencils.sympyextensions.pointers import mem_acc from pystencils.backend.ast.structural import ( PsDeclaration, @@ -14,6 +15,9 @@ from pystencils.backend.ast.structural import ( PsBlock, ) from pystencils.backend.ast.expressions import ( + PsAddressOf, + PsArrayInitList, + PsCast, PsConstantExpr, PsSymbolExpr, PsSubscript, @@ -29,11 +33,12 @@ from pystencils.backend.ast.expressions import ( PsLt, PsCall, PsTernary, + PsMemAcc ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import CFunction from pystencils.types import constify, create_type, create_numeric_type -from pystencils.types.quick import Fp, Int, Bool, Arr +from pystencils.types.quick import Fp, Int, Bool, Arr, Ptr from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -61,7 +66,7 @@ def test_typify_simple(): assert isinstance(fasm, PsDeclaration) def check(expr): - assert expr.dtype == constify(ctx.default_dtype) + assert expr.dtype == ctx.default_dtype match expr: case PsConstantExpr(cs): assert cs.value == 2 @@ -79,41 +84,6 @@ def test_typify_simple(): check(fasm.rhs) -def test_rhs_constness(): - default_type = Fp(32) - ctx = KernelCreationContext(default_dtype=default_type) - - freeze = FreezeExpressions(ctx) - typify = Typifier(ctx) - - f = Field.create_generic( - "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM - ) - f_const = Field.create_generic( - "f_const", - 1, - index_shape=(1,), - dtype=constify(default_type), - field_type=FieldType.CUSTOM, - ) - - x, y, z = sp.symbols("x, y, z") - - # Right-hand sides should always get const types - asm = typify(freeze(Assignment(x, f.absolute_access([0], [0])))) - assert asm.rhs.get_dtype().const - - asm = typify( - freeze( - Assignment( - f.absolute_access([0], [0]), - f.absolute_access([0], [0]) * f_const.absolute_access([0], [0]) * x + y, - ) - ) - ) - assert asm.rhs.get_dtype().const - - def test_lhs_constness(): default_type = Fp(32) ctx = KernelCreationContext(default_dtype=default_type) @@ -133,7 +103,7 @@ def test_lhs_constness(): x, y, z = sp.symbols("x, y, z") - # Assignment RHS may not be const + # Can assign to non-const LHS asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y))) assert not asm.lhs.get_dtype().const @@ -155,7 +125,7 @@ def test_lhs_constness(): q = ctx.get_symbol("q", Fp(32, const=True)) ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q)) ast = typify(ast) - assert ast.lhs.dtype == Fp(32, const=True) + assert ast.lhs.dtype == Fp(32) ast = PsAssignment(PsExpression.make(q), PsExpression.make(q)) with pytest.raises(TypificationError): @@ -197,7 +167,7 @@ def test_default_typing(): expr = typify(expr) def check(expr): - assert expr.dtype == constify(ctx.default_dtype) + assert expr.dtype == ctx.default_dtype match expr: case PsConstantExpr(cs): assert cs.value in (2, 3, -4) @@ -214,6 +184,141 @@ def test_default_typing(): check(expr) +def test_inline_arrays_1d(): + ctx = KernelCreationContext(default_dtype=Fp(32)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x = sp.Symbol("x") + y = TypedSymbol("y", Fp(16)) + idx = TypedSymbol("idx", Int(32)) + + arr: PsArrayInitList = cast(PsArrayInitList, freeze(sp.Tuple(1, 2, 3, 4))) + decl = PsDeclaration(freeze(x), freeze(y) + PsSubscript(arr, (freeze(idx),))) + # The array elements should learn their type from the context, which gets it from `y` + + decl = typify(decl) + assert decl.lhs.dtype == Fp(16) + assert decl.rhs.dtype == Fp(16) + + assert arr.dtype == Arr(Fp(16), (4,)) + for item in arr.items: + assert item.dtype == Fp(16) + + +def test_inline_arrays_3d(): + ctx = KernelCreationContext(default_dtype=Fp(32)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x = sp.Symbol("x") + y = TypedSymbol("y", Fp(16)) + idx = [TypedSymbol(f"idx_{i}", Int(32)) for i in range(3)] + + arr: PsArrayInitList = freeze( + sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10))) + ) + decl = PsDeclaration( + freeze(x), + freeze(y) + PsSubscript(arr, (freeze(idx[0]), freeze(idx[1]), freeze(idx[2]))), + ) + # The array elements should learn their type from the context, which gets it from `y` + + decl = typify(decl) + assert decl.lhs.dtype == Fp(16) + assert decl.rhs.dtype == Fp(16) + + assert arr.dtype == Arr(Fp(16), (2, 3, 2)) + assert arr.shape == (2, 3, 2) + for item in arr.items: + assert item.dtype == Fp(16) + + +def test_array_subscript(): + ctx = KernelCreationContext(default_dtype=Fp(16)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + arr = sp.IndexedBase(TypedSymbol("arr", Arr(Fp(32), (16,)))) + expr = freeze(arr[3]) + expr = typify(expr) + + assert expr.dtype == Fp(32) + + arr = sp.IndexedBase(TypedSymbol("arr2", Arr(Fp(32), (7, 31)))) + expr = freeze(arr[3, 5]) + expr = typify(expr) + + assert expr.dtype == Fp(32) + + +def test_invalid_subscript(): + ctx = KernelCreationContext(default_dtype=Fp(16)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + non_arr = sp.IndexedBase(TypedSymbol("non_arr", Int(64))) + expr = freeze(non_arr[3]) + + with pytest.raises(TypificationError): + expr = typify(expr) + + wrong_shape_arr = sp.IndexedBase( + TypedSymbol("wrong_shape_arr", Arr(Fp(32), (7, 31, 5))) + ) + expr = freeze(wrong_shape_arr[3, 5]) + + with pytest.raises(TypificationError): + expr = typify(expr) + + # raw pointers are not arrays, cannot enter subscript + ptr = sp.IndexedBase( + TypedSymbol("ptr", Ptr(Int(16))) + ) + expr = freeze(ptr[37]) + + with pytest.raises(TypificationError): + expr = typify(expr) + + +def test_mem_acc(): + ctx = KernelCreationContext(default_dtype=Fp(16)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + ptr = TypedSymbol("ptr", Ptr(Int(64))) + idx = TypedSymbol("idx", Int(32)) + + expr = freeze(mem_acc(ptr, idx)) + expr = typify(expr) + + assert isinstance(expr, PsMemAcc) + assert expr.dtype == Int(64) + assert expr.offset.dtype == Int(32) + + +def test_invalid_mem_acc(): + ctx = KernelCreationContext(default_dtype=Fp(16)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + non_ptr = TypedSymbol("non_ptr", Int(64)) + idx = TypedSymbol("idx", Int(32)) + + expr = freeze(mem_acc(non_ptr, idx)) + + with pytest.raises(TypificationError): + _ = typify(expr) + + arr = TypedSymbol("arr", Arr(Int(64), (31,))) + idx = TypedSymbol("idx", Int(32)) + + expr = freeze(mem_acc(arr, idx)) + + with pytest.raises(TypificationError): + _ = typify(expr) + + def test_lhs_inference(): ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) freeze = FreezeExpressions(ctx) @@ -229,13 +334,13 @@ def test_lhs_inference(): fasm = typify(freeze(asm)) assert ctx.get_symbol("x").dtype == Fp(32) - assert fasm.lhs.dtype == constify(Fp(32)) + assert fasm.lhs.dtype == Fp(32) asm = Assignment(y, 3 - w) fasm = typify(freeze(asm)) assert ctx.get_symbol("y").dtype == Fp(16) - assert fasm.lhs.dtype == constify(Fp(16)) + assert fasm.lhs.dtype == Fp(16) fasm = PsAssignment(PsExpression.make(ctx.get_symbol("z")), freeze(3 - w)) fasm = typify(fasm) @@ -249,8 +354,48 @@ def test_lhs_inference(): fasm = typify(fasm) assert ctx.get_symbol("r").dtype == Bool() - assert fasm.lhs.dtype == constify(Bool()) - assert fasm.rhs.dtype == constify(Bool()) + assert fasm.lhs.dtype == Bool() + assert fasm.rhs.dtype == Bool() + + +def test_array_declarations(): + ctx = KernelCreationContext(default_dtype=Fp(32)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x, y, z = sp.symbols("x, y, z") + + # Array type fallback to default + arr1 = sp.Symbol("arr1") + decl = freeze(Assignment(arr1, sp.Tuple(1, 2, 3, 4))) + decl = typify(decl) + + assert ctx.get_symbol("arr1").dtype == Arr(Fp(32), (4,)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (4,)) + + # Array type determined by default-typed symbol + arr2 = sp.Symbol("arr2") + decl = freeze(Assignment(arr2, sp.Tuple((x, y, -7), (3, -2, 51)))) + decl = typify(decl) + + assert ctx.get_symbol("arr2").dtype == Arr(Fp(32), (2, 3)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (2, 3)) + + # Array type determined by pre-typed symbol + q = TypedSymbol("q", Fp(16)) + arr3 = sp.Symbol("arr3") + decl = freeze(Assignment(arr3, sp.Tuple((q, 2), (-q, 0.123)))) + decl = typify(decl) + + assert ctx.get_symbol("arr3").dtype == Arr(Fp(16), (2, 2)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(16), (2, 2)) + + # Array type determined by LHS symbol + arr4 = TypedSymbol("arr4", Arr(Int(16), 4)) + decl = freeze(Assignment(arr4, sp.Tuple(11, 1, 4, 2))) + decl = typify(decl) + + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Int(16), 4) def test_erronous_typing(): @@ -289,17 +434,17 @@ def test_invalid_indices(): ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) typify = Typifier(ctx) - arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64)))) + arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64), (61,)))) x, y, z = [PsExpression.make(ctx.get_symbol(x)) for x in "xyz"] # Using default-typed symbols as array indices is illegal when the default type is a float - fasm = PsAssignment(PsSubscript(arr, x + y), z) + fasm = PsAssignment(PsSubscript(arr, (x + y,)), z) with pytest.raises(TypificationError): typify(fasm) - fasm = PsAssignment(z, PsSubscript(arr, x + y)) + fasm = PsAssignment(z, PsSubscript(arr, (x + y,))) with pytest.raises(TypificationError): typify(fasm) @@ -391,7 +536,7 @@ def test_typify_bools_and_relations(): expr = PsAnd(PsEq(x, y), PsAnd(true, PsNot(PsOr(p, q)))) expr = typify(expr) - assert expr.dtype == Bool(const=True) + assert expr.dtype == Bool() def test_bool_in_numerical_context(): @@ -415,7 +560,7 @@ def test_typify_conditionals(rel): cond = PsConditional(rel(x, y), PsBlock([])) cond = typify(cond) - assert cond.condition.dtype == Bool(const=True) + assert cond.condition.dtype == Bool() def test_invalid_conditions(): @@ -444,11 +589,11 @@ def test_typify_ternary(): expr = PsTernary(p, x, y) expr = typify(expr) - assert expr.dtype == Fp(32, const=True) + assert expr.dtype == Fp(32) expr = PsTernary(PsAnd(p, q), a, b + a) expr = typify(expr) - assert expr.dtype == Int(32, const=True) + assert expr.dtype == Int(32) expr = PsTernary(PsAnd(p, q), a, x) with pytest.raises(TypificationError): @@ -472,9 +617,25 @@ def test_cfunction(): result = typify(PsCall(threeway, [x, y])) - assert result.get_dtype() == Int(32, const=True) - assert result.args[0].get_dtype() == Fp(32, const=True) - assert result.args[1].get_dtype() == Fp(32, const=True) + assert result.get_dtype() == Int(32) + assert result.args[0].get_dtype() == Fp(32) + assert result.args[1].get_dtype() == Fp(32) with pytest.raises(TypificationError): _ = typify(PsCall(threeway, (x, p))) + + +def test_inference_fails(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x = PsExpression.make(PsConstant(42)) + + with pytest.raises(TypificationError): + typify(PsEq(x, x)) + + with pytest.raises(TypificationError): + typify(PsArrayInitList([x])) + + with pytest.raises(TypificationError): + typify(PsCast(ctx.default_dtype, x)) diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py index 09c63a5572f7873648712dbd3279d16f3c342458..2408b8d867038a0f2fd5c4d8a5f22bc82312c701 100644 --- a/tests/nbackend/test_ast.py +++ b/tests/nbackend/test_ast.py @@ -1,14 +1,22 @@ -from pystencils.backend.symbols import PsSymbol +import pytest + +from pystencils import create_type +from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, Typifier +from pystencils.backend.memory import PsSymbol, BufferBasePtr from pystencils.backend.constants import PsConstant from pystencils.backend.ast.expressions import ( PsExpression, PsCast, - PsDeref, + PsMemAcc, + PsArrayInitList, PsSubscript, + PsBufferAcc, + PsSymbolExpr, ) from pystencils.backend.ast.structural import ( PsStatement, PsAssignment, + PsDeclaration, PsBlock, PsConditional, PsComment, @@ -19,15 +27,25 @@ from pystencils.types.quick import Fp, Ptr def test_cloning(): - x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"] + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x, y, z, m = [PsExpression.make(ctx.get_symbol(name)) for name in "xyzm"] + q = PsExpression.make(ctx.get_symbol("q", create_type("bool"))) + a, b, c = [PsExpression.make(ctx.get_symbol(name, ctx.index_dtype)) for name in "abc"] c1 = PsExpression.make(PsConstant(3.0)) c2 = PsExpression.make(PsConstant(-1.0)) - one = PsExpression.make(PsConstant(1)) + one_f = PsExpression.make(PsConstant(1.0)) + one_i = PsExpression.make(PsConstant(1)) def check(orig, clone): assert not (orig is clone) assert type(orig) is type(clone) assert orig.structurally_equal(clone) + + if isinstance(orig, PsExpression): + # Regression: Expression data types used to not be cloned + assert orig.dtype == clone.dtype for c1, c2 in zip(orig.children, clone.children, strict=True): check(c1, c2) @@ -43,14 +61,21 @@ def test_cloning(): PsAssignment(y, x / c1), PsBlock([PsAssignment(x, c1 * y), PsAssignment(z, c2 + c1 * z)]), PsConditional( - y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) + q, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) + ), + PsDeclaration( + m, + PsArrayInitList([ + [x, y, one_f + x], + [one_f, c2, z] + ]) ), PsPragma("omp parallel for"), PsLoop( - x, - y, - z, - one, + a, + b, + c, + one_i, PsBlock( [ PsComment("Loop body"), @@ -58,12 +83,55 @@ def test_cloning(): PsAssignment(x, y), PsPragma("#pragma clang loop vectorize(enable)"), PsStatement( - PsDeref(PsCast(Ptr(Fp(32)), z)) - + PsSubscript(z, one + one + one) + PsMemAcc(PsCast(Ptr(Fp(32)), z), one_i) + + PsCast(Fp(32), PsSubscript(m, (one_i + one_i + one_i, b + one_i))) ), ] ), ), ]: + ast = typify(ast) ast_clone = ast.clone() check(ast, ast_clone) + + +def test_buffer_acc(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + from pystencils import fields + + f, g = fields("f, g(3): [2D]") + a, b = [ctx.get_symbol(n, ctx.index_dtype) for n in "ab"] + + f_buf = ctx.get_buffer(f) + + f_acc = PsBufferAcc(f_buf.base_pointer, [PsExpression.make(i) for i in (a, b)] + [factory.parse_index(0)]) + assert f_acc.buffer == f_buf + assert f_acc.base_pointer.structurally_equal(PsSymbolExpr(f_buf.base_pointer)) + + f_acc_clone = f_acc.clone() + assert f_acc_clone is not f_acc + + assert f_acc_clone.buffer == f_buf + assert f_acc_clone.base_pointer.structurally_equal(PsSymbolExpr(f_buf.base_pointer)) + assert len(f_acc_clone.index) == 3 + assert f_acc_clone.index[0].structurally_equal(PsSymbolExpr(ctx.get_symbol("a"))) + assert f_acc_clone.index[1].structurally_equal(PsSymbolExpr(ctx.get_symbol("b"))) + + g_buf = ctx.get_buffer(g) + + g_acc = PsBufferAcc(g_buf.base_pointer, [PsExpression.make(i) for i in (a, b)] + [factory.parse_index(2)]) + assert g_acc.buffer == g_buf + assert g_acc.base_pointer.structurally_equal(PsSymbolExpr(g_buf.base_pointer)) + + second_bptr = PsExpression.make(ctx.get_symbol("data_g_interior", g_buf.base_pointer.dtype)) + second_bptr.symbol.add_property(BufferBasePtr(g_buf)) + g_acc.base_pointer = second_bptr + + assert g_acc.base_pointer == second_bptr + assert g_acc.buffer == g_buf + + # cannot change base pointer to different buffer + with pytest.raises(ValueError): + g_acc.base_pointer = PsExpression.make(f_buf.base_pointer) diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index 8fb44e748ad828b96e9ac46042527a7218828da5..ef4806314eb52c7389bf583027bb808c42049213 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -3,10 +3,9 @@ from pystencils import Target from pystencils.backend.ast.expressions import PsExpression from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock from pystencils.backend.kernelfunction import KernelFunction -from pystencils.backend.symbols import PsSymbol +from pystencils.backend.memory import PsSymbol, PsBuffer from pystencils.backend.constants import PsConstant from pystencils.backend.literals import PsLiteral -from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer from pystencils.types.quick import Fp, SInt, UInt, Bool from pystencils.backend.emission import CAstPrinter @@ -55,14 +54,15 @@ def test_printing_integer_functions(): PsBitwiseOr, PsBitwiseXor, PsIntDiv, - PsRem + PsRem, ) expr = PsBitwiseAnd( PsBitwiseXor( PsBitwiseXor(j, k), PsBitwiseOr(PsLeftShift(i, PsRightShift(j, k)), PsIntDiv(i, k)), - ) + PsRem(i, k), + ) + + PsRem(i, k), i, ) code = cprint(expr) @@ -154,3 +154,32 @@ def test_ternary(): expr = PsTernary(PsTernary(p, q, PsOr(p, q)), x, y) code = cprint(expr) assert code == "(p ? q : p || q) ? x : y" + + +def test_arrays(): + import sympy as sp + from pystencils import Assignment + from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory + + ctx = KernelCreationContext(default_dtype=SInt(32)) + factory = AstFactory(ctx) + cprint = CAstPrinter() + + arr_1d = factory.parse_sympy(Assignment(sp.Symbol("a1d"), sp.Tuple(1, 2, 3, 4, 5))) + code = cprint(arr_1d) + assert code == "int32_t a1d[5] = { 1, 2, 3, 4, 5 };" + + arr_2d = factory.parse_sympy( + Assignment(sp.Symbol("a2d"), sp.Tuple((1, -1), (2, -2))) + ) + code = cprint(arr_2d) + assert code == "int32_t a2d[2][2] = { { 1, -1 }, { 2, -2 } };" + + arr_3d = factory.parse_sympy( + Assignment(sp.Symbol("a3d"), sp.Tuple(((1, -1), (2, -2)), ((3, -3), (4, -4)))) + ) + code = cprint(arr_3d) + assert ( + code + == "int32_t a3d[2][2][2] = { { { 1, -1 }, { 2, -2 } }, { { 3, -3 }, { 4, -4 } } };" + ) diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py index b621829ad7e72383ec6651015b7813e0a009839b..648112ef95bf5d6c3181f5c3c2527dd870220f0e 100644 --- a/tests/nbackend/test_cpujit.py +++ b/tests/nbackend/test_cpujit.py @@ -3,11 +3,10 @@ import pytest from pystencils import Target # from pystencils.backend.constraints import PsKernelParamsConstraint -from pystencils.backend.symbols import PsSymbol +from pystencils.backend.memory import PsSymbol, PsBuffer from pystencils.backend.constants import PsConstant -from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer -from pystencils.backend.ast.expressions import PsArrayAccess, PsExpression +from pystencils.backend.ast.expressions import PsBufferAcc, PsExpression from pystencils.backend.ast.structural import PsAssignment, PsBlock, PsLoop from pystencils.backend.kernelfunction import KernelFunction @@ -21,8 +20,8 @@ import numpy as np def test_pairwise_addition(): idx_type = SInt(64) - u = PsLinearizedArray("u", Fp(64, const=True), (...,), (...,), index_dtype=idx_type) - v = PsLinearizedArray("v", Fp(64), (...,), (...,), index_dtype=idx_type) + u = PsBuffer("u", Fp(64, const=True), (...,), (...,), index_dtype=idx_type) + v = PsBuffer("v", Fp(64), (...,), (...,), index_dtype=idx_type) u_data = PsArrayBasePointer("u_data", u) v_data = PsArrayBasePointer("v_data", v) @@ -34,8 +33,8 @@ def test_pairwise_addition(): two = PsExpression.make(PsConstant(2, idx_type)) update = PsAssignment( - PsArrayAccess(v_data, loop_ctr), - PsArrayAccess(u_data, two * loop_ctr) + PsArrayAccess(u_data, two * loop_ctr + one) + PsBufferAcc(v_data, loop_ctr), + PsBufferAcc(u_data, two * loop_ctr) + PsBufferAcc(u_data, two * loop_ctr + one) ) loop = PsLoop( diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py index 16e610a552b426cc4245e2e5c4ee36663f6c2bfa..b1403185cff814534d5911a5caf57c4a18892c00 100644 --- a/tests/nbackend/test_extensions.py +++ b/tests/nbackend/test_extensions.py @@ -3,7 +3,7 @@ import sympy as sp from pystencils import make_slice, Field, Assignment from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace -from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations +from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations, LowerToC from pystencils.backend.literals import PsLiteral from pystencils.backend.emission import CAstPrinter from pystencils.backend.ast.expressions import PsExpression, PsSubscript @@ -18,13 +18,13 @@ def test_literals(): f = Field.create_generic("f", 3) x = sp.Symbol("x") - cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64, const=True), 3))) + cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64), (3,), const=True))) global_constant = PsExpression.make(PsLiteral("C", ctx.default_dtype)) loop_slice = make_slice[ - 0:PsSubscript(cells, factory.parse_index(0)), - 0:PsSubscript(cells, factory.parse_index(1)), - 0:PsSubscript(cells, factory.parse_index(2)), + 0:PsSubscript(cells, (factory.parse_index(0),)), + 0:PsSubscript(cells, (factory.parse_index(1),)), + 0:PsSubscript(cells, (factory.parse_index(2),)), ] ispace = FullIterationSpace.create_from_slice(ctx, loop_slice) @@ -46,6 +46,9 @@ def test_literals(): hoist = HoistLoopInvariantDeclarations(ctx) ast = hoist(ast) + lower = LowerToC(ctx) + ast = lower(ast) + assert isinstance(ast, PsBlock) assert len(ast.statements) == 2 assert ast.statements[0] == x_decl diff --git a/tests/nbackend/test_memory.py b/tests/nbackend/test_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..5841e0f4f0185483f47e544a2006f65f71daae12 --- /dev/null +++ b/tests/nbackend/test_memory.py @@ -0,0 +1,50 @@ +import pytest + +from dataclasses import dataclass +from pystencils.backend.memory import PsSymbol, PsSymbolProperty, UniqueSymbolProperty + + +def test_properties(): + @dataclass(frozen=True) + class NumbersProperty(PsSymbolProperty): + n: int + x: float + + @dataclass(frozen=True) + class StringProperty(PsSymbolProperty): + s: str + + @dataclass(frozen=True) + class MyUniqueProperty(UniqueSymbolProperty): + val: int + + s = PsSymbol("s") + + assert not s.properties + + s.add_property(NumbersProperty(42, 8.71)) + assert s.properties == {NumbersProperty(42, 8.71)} + + # no duplicates + s.add_property(NumbersProperty(42, 8.71)) + assert s.properties == {NumbersProperty(42, 8.71)} + + s.add_property(StringProperty("pystencils")) + assert s.properties == {NumbersProperty(42, 8.71), StringProperty("pystencils")} + + assert s.get_properties(NumbersProperty) == {NumbersProperty(42, 8.71)} + + assert not s.get_properties(MyUniqueProperty) + + s.add_property(MyUniqueProperty(13)) + assert s.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)} + + # Adding the same one again does not raise + s.add_property(MyUniqueProperty(13)) + assert s.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)} + + with pytest.raises(ValueError): + s.add_property(MyUniqueProperty(14)) + + s.remove_property(MyUniqueProperty(13)) + assert not s.get_properties(MyUniqueProperty) diff --git a/tests/nbackend/transformations/test_canonical_clone.py b/tests/nbackend/transformations/test_canonical_clone.py index b158b91781b49f8d589a3da3b266e8c2137fceab..b5e100ea590283ff42d83308f2a592b87c1231d2 100644 --- a/tests/nbackend/transformations/test_canonical_clone.py +++ b/tests/nbackend/transformations/test_canonical_clone.py @@ -22,8 +22,8 @@ def test_clone_entire_ast(): rho = sp.Symbol("rho") u = sp.symbols("u_:2") - cx = TypedSymbol("cx", Arr(ctx.default_dtype)) - cy = TypedSymbol("cy", Arr(ctx.default_dtype)) + cx = TypedSymbol("cx", Arr(ctx.default_dtype, (5,))) + cy = TypedSymbol("cy", Arr(ctx.default_dtype, (5,))) cxs = sp.IndexedBase(cx, shape=(5,)) cys = sp.IndexedBase(cy, shape=(5,)) diff --git a/tests/nbackend/transformations/test_canonicalize_symbols.py b/tests/nbackend/transformations/test_canonicalize_symbols.py index 6a478556469b6409507ae7b4ccc545311ab38462..a11e9bd1353ef98c5ae23e0d86ddce3e5c9d579c 100644 --- a/tests/nbackend/transformations/test_canonicalize_symbols.py +++ b/tests/nbackend/transformations/test_canonicalize_symbols.py @@ -52,7 +52,7 @@ def test_deduplication(): assert canonicalize.get_last_live_symbols() == { ctx.find_symbol("y"), ctx.find_symbol("z"), - ctx.get_array(f).base_pointer, + ctx.get_buffer(f).base_pointer, } assert ctx.find_symbol("x") is not None diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index 92bb5c947b4bc2e4b6d50064a9c07874cdda43cf..4c18970086b4537dec2ae974ddf6242da57b591e 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -1,6 +1,6 @@ from pystencils.backend.kernelcreation import KernelCreationContext, Typifier from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr -from pystencils.backend.symbols import PsSymbol +from pystencils.backend.memory import PsSymbol from pystencils.backend.constants import PsConstant from pystencils.backend.transformations import EliminateConstants diff --git a/tests/nbackend/transformations/test_hoist_invariants.py b/tests/nbackend/transformations/test_hoist_invariants.py index 15514f1da8a5f71102242f51192c9a453def957a..daa2760c0b376dc0bb1f2ca59703a15efc5c2312 100644 --- a/tests/nbackend/transformations/test_hoist_invariants.py +++ b/tests/nbackend/transformations/test_hoist_invariants.py @@ -144,14 +144,14 @@ def test_hoist_arrays(): const_arr_symb = TypedSymbol( "const_arr", - Arr(Fp(64, const=True), 10), + Arr(Fp(64), (10,), const=True), ) const_array_decl = factory.parse_sympy(Assignment(const_arr_symb, tuple(range(10)))) const_arr = sp.IndexedBase(const_arr_symb, shape=(10,)) arr_symb = TypedSymbol( "arr", - Arr(Fp(64, const=False), 10), + Arr(Fp(64), (10,), const=False), ) array_decl = factory.parse_sympy(Assignment(arr_symb, tuple(range(10)))) arr = sp.IndexedBase(arr_symb, shape=(10,)) diff --git a/tests/nbackend/transformations/test_lower_to_c.py b/tests/nbackend/transformations/test_lower_to_c.py new file mode 100644 index 0000000000000000000000000000000000000000..b557a7493f9a84cb13b511e8fca1f898823bc9bb --- /dev/null +++ b/tests/nbackend/transformations/test_lower_to_c.py @@ -0,0 +1,122 @@ +from functools import reduce +from operator import add + +from pystencils import fields, Assignment, make_slice, Field, FieldType +from pystencils.types import PsStructType, create_type + +from pystencils.backend.memory import BufferBasePtr +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.transformations import LowerToC + +from pystencils.backend.ast import dfs_preorder +from pystencils.backend.ast.expressions import ( + PsBufferAcc, + PsMemAcc, + PsSymbolExpr, + PsExpression, + PsLookup, + PsAddressOf, + PsCast, +) +from pystencils.backend.ast.structural import PsAssignment + + +def test_lower_buffer_accesses(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:42, :31]) + ctx.set_iteration_space(ispace) + + lower = LowerToC(ctx) + + f, g = fields("f(2), g(3): [2D]") + asm = Assignment(f.center(1), g[-1, 1](2)) + + f_buf = ctx.get_buffer(f) + g_buf = ctx.get_buffer(g) + + fasm = factory.parse_sympy(asm) + assert isinstance(fasm.lhs, PsBufferAcc) + assert isinstance(fasm.rhs, PsBufferAcc) + + fasm_lowered = lower(fasm) + assert isinstance(fasm_lowered, PsAssignment) + + assert isinstance(fasm_lowered.lhs, PsMemAcc) + assert isinstance(fasm_lowered.lhs.pointer, PsSymbolExpr) + assert fasm_lowered.lhs.pointer.symbol == f_buf.base_pointer + + zero = factory.parse_index(0) + expected_offset = reduce( + add, + ( + (PsExpression.make(dm.counter) + zero) * PsExpression.make(stride) + for dm, stride in zip(ispace.dimensions, f_buf.strides) + ), + ) + factory.parse_index(1) * PsExpression.make(f_buf.strides[-1]) + assert fasm_lowered.lhs.offset.structurally_equal(expected_offset) + + assert isinstance(fasm_lowered.rhs, PsMemAcc) + assert isinstance(fasm_lowered.rhs.pointer, PsSymbolExpr) + assert fasm_lowered.rhs.pointer.symbol == g_buf.base_pointer + + expected_offset = ( + (PsExpression.make(ispace.dimensions[0].counter) + factory.parse_index(-1)) + * PsExpression.make(g_buf.strides[0]) + + (PsExpression.make(ispace.dimensions[1].counter) + factory.parse_index(1)) + * PsExpression.make(g_buf.strides[1]) + + factory.parse_index(2) * PsExpression.make(g_buf.strides[-1]) + ) + assert fasm_lowered.rhs.offset.structurally_equal(expected_offset) + + +def test_lower_anonymous_structs(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:12]) + ctx.set_iteration_space(ispace) + + lower = LowerToC(ctx) + + stype = PsStructType( + [ + ("val", ctx.default_dtype), + ("x", ctx.index_dtype), + ] + ) + sfield = Field.create_generic("s", spatial_dimensions=1, dtype=stype) + f = Field.create_generic("f", 1, ctx.default_dtype, field_type=FieldType.CUSTOM) + + asm = Assignment(sfield.center("val"), f.absolute_access((sfield.center("x"),), (0,))) + + fasm = factory.parse_sympy(asm) + + sbuf = ctx.get_buffer(sfield) + + assert isinstance(fasm, PsAssignment) + assert isinstance(fasm.lhs, PsLookup) + + lowered_fasm = lower(fasm.clone()) + assert isinstance(lowered_fasm, PsAssignment) + + # Check type of sfield data pointer + for expr in dfs_preorder(lowered_fasm, lambda n: isinstance(n, PsSymbolExpr)): + if expr.symbol.name == sbuf.base_pointer.name: + assert expr.symbol.dtype == create_type("uint8_t * restrict") + + # Check LHS + assert isinstance(lowered_fasm.lhs, PsMemAcc) + assert isinstance(lowered_fasm.lhs.pointer, PsCast) + assert isinstance(lowered_fasm.lhs.pointer.operand, PsAddressOf) + assert isinstance(lowered_fasm.lhs.pointer.operand.operand, PsMemAcc) + type_erased_pointer = lowered_fasm.lhs.pointer.operand.operand.pointer + + assert isinstance(type_erased_pointer, PsSymbolExpr) + assert BufferBasePtr(sbuf) in type_erased_pointer.symbol.properties + assert type_erased_pointer.symbol.dtype == create_type("uint8_t * restrict") diff --git a/tests/types/test_types.py b/tests/types/test_types.py index 1cc2ae0e4a213df51ccf80578a6ba028771e4f0c..165d572de5d191e759e5d8a6bea06c0f71884374 100644 --- a/tests/types/test_types.py +++ b/tests/types/test_types.py @@ -151,6 +151,48 @@ def test_struct_types(): assert t.itemsize == numpy_type.itemsize == 16 +def test_array_types(): + t = PsArrayType(UInt(64), 42) + assert t.dim == 1 + assert t.shape == (42,) + assert not t.const + assert t.c_string() == "uint64_t[42]" + + assert t == PsArrayType(UInt(64), (42,)) + + t = PsArrayType(UInt(64), [3, 4, 5]) + assert t.dim == 3 + assert t.shape == (3, 4, 5) + assert not t.const + assert t.c_string() == "uint64_t[3][4][5]" + + t = PsArrayType(UInt(64, const=True), [3, 4, 5]) + assert t.dim == 3 + assert t.shape == (3, 4, 5) + assert not t.const + + t = PsArrayType(UInt(64), [3, 4, 5], const=True) + assert t.dim == 3 + assert t.shape == (3, 4, 5) + assert t.const + assert t.c_string() == "const uint64_t[3][4][5]" + + t = PsArrayType(UInt(64, const=True), [3, 4, 5], const=True) + assert t.dim == 3 + assert t.shape == (3, 4, 5) + assert t.const + + with pytest.raises(ValueError): + _ = PsArrayType(UInt(64), (3, 0, 1)) + + with pytest.raises(ValueError): + _ = PsArrayType(UInt(64), (3, 9, -1, 2)) + + # Nested arrays are disallowed + with pytest.raises(ValueError): + _ = PsArrayType(PsArrayType(Bool(), (2,)), (3, 1)) + + def test_pickle(): types = [ Bool(const=True), @@ -165,8 +207,8 @@ def test_pickle(): Fp(width=16, const=True), PsStructType([("x", UInt(32)), ("y", UInt(32)), ("val", Fp(64))], "myStruct"), PsStructType([("data", Fp(32))], "None"), - PsArrayType(Fp(16, const=True), 42), - PsArrayType(PsVectorType(Fp(32), 8, const=False), 42) + PsArrayType(Fp(16), (42,), const=True), + PsArrayType(PsVectorType(Fp(32), 8), (42,)) ] dumped = pickle.dumps(types)