diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py deleted file mode 100644 index 9aefeaf62fe65fad045a6d2f20e72ce732f8aa8b..0000000000000000000000000000000000000000 --- a/src/pystencils/backend/arrays.py +++ /dev/null @@ -1,194 +0,0 @@ -from __future__ import annotations - -from typing import Sequence -from types import EllipsisType - -from abc import ABC - -from .constants import PsConstant -from ..types import ( - PsType, - PsPointerType, - PsIntegerType, - PsUnsignedIntegerType, -) - -from .symbols import PsSymbol -from ..defaults import DEFAULTS - - -class PsLinearizedArray: - """Class to model N-dimensional contiguous arrays. - - **Memory Layout, Shape and Strides** - - The memory layout of an array is defined by its shape and strides. - Both shape and stride entries may either be constants or special variables associated with - exactly one array. - - Shape and strides may be specified at construction in the following way. - For constant entries, their value must be given as an integer. - For variable shape entries and strides, the Ellipsis `...` must be passed instead. - Internally, the passed ``index_dtype`` will be used to create typed constants (`PsConstant`) - and variables (`PsArrayShapeSymbol` and `PsArrayStrideSymbol`) from the passed values. - """ - - def __init__( - self, - name: str, - element_type: PsType, - shape: Sequence[int | str | EllipsisType], - strides: Sequence[int | str | EllipsisType], - index_dtype: PsIntegerType = DEFAULTS.index_dtype, - ): - self._name = name - self._element_type = element_type - self._index_dtype = index_dtype - - if len(shape) != len(strides): - raise ValueError("Shape and stride tuples must have the same length") - - def make_shape(coord, name_or_val): - match name_or_val: - case EllipsisType(): - return PsArrayShapeSymbol(DEFAULTS.field_shape_name(name, coord), self, coord) - case str(): - return PsArrayShapeSymbol(name_or_val, self, coord) - case _: - return PsConstant(name_or_val, index_dtype) - - self._shape: tuple[PsArrayShapeSymbol | PsConstant, ...] = tuple( - make_shape(i, s) for i, s in enumerate(shape) - ) - - def make_stride(coord, name_or_val): - match name_or_val: - case EllipsisType(): - return PsArrayStrideSymbol(DEFAULTS.field_stride_name(name, coord), self, coord) - case str(): - return PsArrayStrideSymbol(name_or_val, self, coord) - case _: - return PsConstant(name_or_val, index_dtype) - - self._strides: tuple[PsArrayStrideSymbol | PsConstant, ...] = tuple( - make_stride(i, s) for i, s in enumerate(strides) - ) - - self._base_ptr = PsArrayBasePointer(DEFAULTS.field_pointer_name(name), self) - - @property - def name(self): - """The array's name""" - return self._name - - @property - def base_pointer(self) -> PsArrayBasePointer: - """The array's base pointer""" - return self._base_ptr - - @property - def shape(self) -> tuple[PsArrayShapeSymbol | PsConstant, ...]: - """The array's shape, expressed using `PsConstant` and `PsArrayShapeSymbol`""" - return self._shape - - @property - def strides(self) -> tuple[PsArrayStrideSymbol | PsConstant, ...]: - """The array's strides, expressed using `PsConstant` and `PsArrayStrideSymbol`""" - return self._strides - - @property - def index_type(self) -> PsIntegerType: - return self._index_dtype - - @property - def element_type(self) -> PsType: - return self._element_type - - def __repr__(self) -> str: - return ( - f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])" - ) - - -class PsArrayAssocSymbol(PsSymbol, ABC): - """A variable that is associated to an array. - - Instances of this class represent pointers and indexing information bound - to a particular array. - """ - - __match_args__ = ("name", "dtype", "array") - - def __init__(self, name: str, dtype: PsType, array: PsLinearizedArray): - super().__init__(name, dtype) - self._array = array - - @property - def array(self) -> PsLinearizedArray: - return self._array - - -class PsArrayBasePointer(PsArrayAssocSymbol): - def __init__(self, name: str, array: PsLinearizedArray): - dtype = PsPointerType(array.element_type) - super().__init__(name, dtype, array) - - self._array = array - - -class TypeErasedBasePointer(PsArrayBasePointer): - """Base pointer for arrays whose element type has been erased. - - Used primarily for arrays of anonymous structs.""" - - def __init__(self, name: str, array: PsLinearizedArray): - dtype = PsPointerType(PsUnsignedIntegerType(8)) - super(PsArrayBasePointer, self).__init__(name, dtype, array) - - self._array = array - - -class PsArrayShapeSymbol(PsArrayAssocSymbol): - """Variable that represents an array's shape in one coordinate. - - Do not instantiate this class yourself, but only use its instances - as provided by `PsLinearizedArray.shape`. - """ - - __match_args__ = PsArrayAssocSymbol.__match_args__ + ("coordinate",) - - def __init__( - self, - name: str, - array: PsLinearizedArray, - coordinate: int, - ): - super().__init__(name, array.index_type, array) - self._coordinate = coordinate - - @property - def coordinate(self) -> int: - return self._coordinate - - -class PsArrayStrideSymbol(PsArrayAssocSymbol): - """Variable that represents an array's stride in one coordinate. - - Do not instantiate this class yourself, but only use its instances - as provided by `PsLinearizedArray.strides`. - """ - - __match_args__ = PsArrayAssocSymbol.__match_args__ + ("coordinate",) - - def __init__( - self, - name: str, - array: PsLinearizedArray, - coordinate: int, - ): - super().__init__(name, array.index_type, array) - self._coordinate = coordinate - - @property - def coordinate(self) -> int: - return self._coordinate diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index 5b37470cc89dc8b304d9fd033ea66ec67b9f8d8a..4943ac4095b9e0f1e5763bb73a70e6756b2950a2 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -30,7 +30,7 @@ from .expressions import ( PsTernary, ) -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..exceptions import PsInternalCompilerError from ...types import PsNumericType diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 151f86c6e3db4e5c4f7914d7da9e94240412cdb2..286958ee2f226eb2b6a92b2d7f5959e399418ddf 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -7,10 +7,10 @@ import operator import numpy as np from numpy.typing import NDArray -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..constants import PsConstant from ..literals import PsLiteral -from ..arrays import PsLinearizedArray, PsArrayBasePointer +from ..memory import PsBuffer from ..functions import PsFunction from ...types import ( PsType, @@ -315,7 +315,7 @@ class PsArrayAccess(PsMemAcc): self._ptr = expr @property - def array(self) -> PsLinearizedArray: + def array(self) -> PsBuffer: return self._base_ptr.array @property diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index cd3aae30d35061ab6c15c338a735aaecca83a141..57244c03b6413d3fd8c7b521618cb9021b0e2037 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -4,7 +4,7 @@ from types import NoneType from .astnode import PsAstNode, PsLeafMixIn from .expressions import PsExpression, PsLvalue, PsSymbolExpr -from ..symbols import PsSymbol +from ..memory import PsSymbol from .util import failing_cast diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py index 72aff0a01c83d5c0df5acca9c35691a12f763a2d..b7bde603f030b0814a118899dd7c45246be22418 100644 --- a/src/pystencils/backend/ast/util.py +++ b/src/pystencils/backend/ast/util.py @@ -2,8 +2,8 @@ from __future__ import annotations from typing import Any, TYPE_CHECKING, cast from ..exceptions import PsInternalCompilerError -from ..symbols import PsSymbol -from ..arrays import PsLinearizedArray +from ..memory import PsSymbol +from ..memory import PsBuffer from ...types import PsDereferencableType @@ -47,7 +47,7 @@ class AstEqWrapper: def determine_memory_object( expr: PsExpression, -) -> tuple[PsSymbol | PsLinearizedArray | None, bool]: +) -> tuple[PsSymbol | PsBuffer | None, bool]: """Return the memory object accessed by the given expression, together with its constness Returns: diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 579d4764835729bb98e17a085898b8b7f1612f27..e0a1f4242f2dadde44b23501ecd6af6c1076b834 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -55,7 +55,7 @@ from .ast.expressions import ( from .extensions.foreign_ast import PsForeignExpression -from .symbols import PsSymbol +from .memory import PsSymbol from ..types import PsScalarType, PsArrayType from .kernelfunction import KernelFunction, GpuKernelFunction diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index 0dc60b1b1c1f4478d6eb82496f6a80be6fb1d626..d6084dbc7bca97aa8ff6bf3f1f9766e4e70c0561 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -8,7 +8,7 @@ from ..ast import PsAstNode from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr from ..ast.structural import PsLoop, PsBlock, PsAssignment -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..constants import PsConstant from .context import KernelCreationContext diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 73e3c70cc315c85c9ec01d2db401d10eaa70c53f..e75144dee34051311ba14947d444a883f9d26717 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -9,14 +9,13 @@ from ...defaults import DEFAULTS from ...field import Field, FieldType from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType -from ..symbols import PsSymbol -from ..arrays import PsLinearizedArray +from ..memory import PsSymbol, PsBuffer, FieldShape, FieldStride +from ..constants import PsConstant from ...types import ( PsType, PsIntegerType, PsNumericType, - PsScalarType, - PsStructType, + PsPointerType, deconstify, ) from ..constraints import KernelParamsConstraint @@ -221,64 +220,14 @@ class KernelCreationContext: else: return - arr_shape: list[str | int] | None = None - arr_strides: list[str | int] | None = None - - def normalize_type(s: TypedSymbol) -> PsIntegerType: - match s.dtype: - case DynamicType.INDEX_TYPE: - return self.index_dtype - case DynamicType.NUMERIC_TYPE: - if isinstance(self.default_dtype, PsIntegerType): - return self.default_dtype - else: - raise KernelConstraintsError( - f"Cannot use non-integer default numeric type {self.default_dtype} " - f"in field indexing symbol {s}." - ) - case PsIntegerType(): - return deconstify(s.dtype) - case _: - raise KernelConstraintsError( - f"Invalid data type for field indexing symbol {s}: {s.dtype}" - ) - - # Check field constraints and add to collection + # Check field constraints, create buffer, and add them to the collection match field.field_type: case FieldType.GENERIC | FieldType.STAGGERED | FieldType.STAGGERED_FLUX: + buf = self._create_regular_field_buffer(field) self._fields_collection.domain_fields.add(field) case FieldType.BUFFER: - if field.spatial_dimensions != 1: - raise KernelConstraintsError( - f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. " - "Buffer fields must be one-dimensional." - ) - - if field.index_dimensions > 1: - raise KernelConstraintsError( - f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. " - "Buffer fields can have at most one index dimension." - ) - - num_entries = field.index_shape[0] if field.index_shape else 1 - if not isinstance(num_entries, int): - raise KernelConstraintsError( - f"Invalid index shape of buffer field {field.name}: {num_entries}. " - "Buffer fields cannot have variable index shape." - ) - - buffer_len = field.spatial_shape[0] - - if isinstance(buffer_len, TypedSymbol): - idx_type = normalize_type(buffer_len) - arr_shape = [buffer_len.name, num_entries] - else: - idx_type = DEFAULTS.index_dtype - arr_shape = [buffer_len, num_entries] - - arr_strides = [num_entries, 1] - + buf = self._create_buffer_field_buffer(field) self._fields_collection.buffer_fields.add(field) case FieldType.INDEXED: @@ -287,6 +236,7 @@ class KernelCreationContext: f"Invalid spatial shape of index field {field.name}: {field.spatial_dimensions}. " "Index fields must be one-dimensional." ) + buf = self._create_regular_field_buffer(field) self._fields_collection.index_fields.add(field) case FieldType.CUSTOM: @@ -295,59 +245,14 @@ class KernelCreationContext: case _: assert False, "unreachable code" - # For non-buffer fields, determine shape and strides - - if arr_shape is None: - idx_types = set( - normalize_type(s) - for s in chain(field.shape, field.strides) - if isinstance(s, TypedSymbol) - ) - - if len(idx_types) > 1: - raise KernelConstraintsError( - f"Multiple incompatible types found in index symbols of field {field}: " - f"{idx_types}" - ) - idx_type = idx_types.pop() if len(idx_types) > 0 else self.index_dtype - - arr_shape = [ - (s.name if isinstance(s, TypedSymbol) else s) for s in field.shape - ] - - arr_strides = [ - (s.name if isinstance(s, TypedSymbol) else s) for s in field.strides - ] - - # The frontend doesn't quite agree with itself on how to model - # fields with trivial index dimensions. Sometimes the index_shape is empty, - # sometimes its (1,). This is canonicalized here. - if not field.index_shape: - arr_shape += [1] - arr_strides += [1] - - # Add array - assert arr_strides is not None - assert idx_type is not None - - assert isinstance(field.dtype, (PsScalarType, PsStructType)) - element_type = field.dtype - - arr = PsLinearizedArray( - field.name, element_type, arr_shape, arr_strides, idx_type - ) - - self._fields_and_arrays[field.name] = FieldArrayPair(field, arr) - for symb in chain([arr.base_pointer], arr.shape, arr.strides): - if isinstance(symb, PsSymbol): - self.add_symbol(symb) + self._fields_and_arrays[field.name] = FieldArrayPair(field, buf) @property - def arrays(self) -> Iterable[PsLinearizedArray]: + def arrays(self) -> Iterable[PsBuffer]: # return self._fields_and_arrays.values() yield from (item.array for item in self._fields_and_arrays.values()) - def get_array(self, field: Field) -> PsLinearizedArray: + def get_buffer(self, field: Field) -> PsBuffer: """Retrieve the underlying array for a given field. If the given field was not previously registered using `add_field`, @@ -393,3 +298,114 @@ class KernelCreationContext: def require_header(self, header: str): self._req_headers.add(header) + + # ----------- Internals --------------------------------------------------------------------- + + def _normalize_type(self, s: TypedSymbol) -> PsIntegerType: + match s.dtype: + case DynamicType.INDEX_TYPE: + return self.index_dtype + case DynamicType.NUMERIC_TYPE: + if isinstance(self.default_dtype, PsIntegerType): + return self.default_dtype + else: + raise KernelConstraintsError( + f"Cannot use non-integer default numeric type {self.default_dtype} " + f"in field indexing symbol {s}." + ) + case PsIntegerType(): + return deconstify(s.dtype) + case _: + raise KernelConstraintsError( + f"Invalid data type for field indexing symbol {s}: {s.dtype}" + ) + + def _create_regular_field_buffer(self, field: Field) -> PsBuffer: + idx_types = set( + self._normalize_type(s) + for s in chain(field.shape, field.strides) + if isinstance(s, TypedSymbol) + ) + + if len(idx_types) > 1: + raise KernelConstraintsError( + f"Multiple incompatible types found in index symbols of field {field}: " + f"{idx_types}" + ) + + idx_type = idx_types.pop() if len(idx_types) > 0 else self.index_dtype + + def convert_size(s: TypedSymbol | int) -> PsSymbol | PsConstant: + if isinstance(s, TypedSymbol): + return self.get_symbol(s.name, idx_type) + else: + return PsConstant(s, idx_type) + + buf_shape = [convert_size(s) for s in field.shape] + buf_strides = [convert_size(s) for s in field.strides] + + # The frontend doesn't quite agree with itself on how to model + # fields with trivial index dimensions. Sometimes the index_shape is empty, + # sometimes its (1,). This is canonicalized here. + if not field.index_shape: + buf_shape += [convert_size(1)] + buf_strides += [convert_size(1)] + + for i, size in enumerate(buf_shape): + if isinstance(size, PsSymbol): + size.add_property(FieldShape(field, i)) + + for i, stride in enumerate(buf_strides): + if isinstance(stride, PsSymbol): + stride.add_property(FieldStride(field, i)) + + base_ptr = self.get_symbol( + DEFAULTS.field_pointer_name(field.name), + PsPointerType(field.dtype, restrict=True), + ) + + return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides) + + def _create_buffer_field_buffer(self, field: Field) -> PsBuffer: + if field.spatial_dimensions != 1: + raise KernelConstraintsError( + f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. " + "Buffer fields must be one-dimensional." + ) + + if field.index_dimensions > 1: + raise KernelConstraintsError( + f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. " + "Buffer fields can have at most one index dimension." + ) + + num_entries = field.index_shape[0] if field.index_shape else 1 + if not isinstance(num_entries, int): + raise KernelConstraintsError( + f"Invalid index shape of buffer field {field.name}: {num_entries}. " + "Buffer fields cannot have variable index shape." + ) + + buffer_len = field.spatial_shape[0] + buf_shape: list[PsSymbol | PsConstant] + + if isinstance(buffer_len, TypedSymbol): + idx_type = self._normalize_type(buffer_len) + len_symb = self.get_symbol(buffer_len.name, idx_type) + len_symb.add_property(FieldShape(field, 0)) + buf_shape = [len_symb, PsConstant(num_entries, idx_type)] + else: + idx_type = DEFAULTS.index_dtype + buf_shape = [ + PsConstant(buffer_len, idx_type), + PsConstant(num_entries, idx_type), + ] + + buf_strides = [PsConstant(num_entries, idx_type), PsConstant(1, idx_type)] + + base_ptr = self.get_symbol( + DEFAULTS.field_pointer_name(field.name), + PsPointerType(field.dtype, restrict=True), + ) + + return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 0ae5a0d1be80ee7a02c667372c30620aef6e77fd..11ac929bfadb9cb9e62ddd8bab29d10c8d956d8b 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -309,7 +309,7 @@ class FreezeExpressions: def map_Access(self, access: Field.Access): field = access.field - array = self._ctx.get_array(field) + array = self._ctx.get_buffer(field) ptr = array.base_pointer offsets: list[PsExpression] = [ diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 8175fffed9782d9271262b7bc220e4dcdc208705..bae0328e4348836463f2d6831fc1905855354548 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -9,10 +9,9 @@ from ...defaults import DEFAULTS from ...simp import AssignmentCollection from ...field import Field, FieldType -from ..symbols import PsSymbol +from ..memory import PsSymbol, PsBuffer from ..constants import PsConstant from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem -from ..arrays import PsLinearizedArray from ..ast.util import failing_cast from ...types import PsStructType, constify from ..exceptions import PsInputError, KernelConstraintsError @@ -74,7 +73,7 @@ class FullIterationSpace(IterationSpace): ) -> FullIterationSpace: """Create an iteration space over an archetype field with ghost layers.""" - archetype_array = ctx.get_array(archetype_field) + archetype_array = ctx.get_buffer(archetype_field) dim = archetype_field.spatial_dimensions counters = [ @@ -142,7 +141,7 @@ class FullIterationSpace(IterationSpace): archetype_size: tuple[PsSymbol | PsConstant | None, ...] if archetype_field is not None: - archetype_array = ctx.get_array(archetype_field) + archetype_array = ctx.get_buffer(archetype_field) if archetype_field.spatial_dimensions != dim: raise ValueError( @@ -281,7 +280,7 @@ class SparseIterationSpace(IterationSpace): def __init__( self, spatial_indices: Sequence[PsSymbol], - index_list: PsLinearizedArray, + index_list: PsBuffer, coordinate_members: Sequence[PsStructType.Member], sparse_counter: PsSymbol, ): @@ -291,7 +290,7 @@ class SparseIterationSpace(IterationSpace): self._sparse_counter = sparse_counter @property - def index_list(self) -> PsLinearizedArray: + def index_list(self) -> PsBuffer: return self._index_list @property @@ -365,7 +364,7 @@ def create_sparse_iteration_space( # Determine index field if index_field is not None: - idx_arr = ctx.get_array(index_field) + idx_arr = ctx.get_buffer(index_field) idx_struct_type: PsStructType = failing_cast(PsStructType, idx_arr.element_type) for coord in coord_members: diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index a3213350e2c21b91a9e3c41f2e582992f293d7fb..568b6c7994add39d9e952cad1364db51f04ceefa 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -8,8 +8,7 @@ from .._deprecation import _deprecated from .ast.structural import PsBlock from .ast.analysis import collect_required_headers, collect_undefined_symbols -from .arrays import PsArrayShapeSymbol, PsArrayStrideSymbol, PsArrayBasePointer -from .symbols import PsSymbol +from .memory import PsSymbol from .kernelcreation.context import KernelCreationContext from .platforms import Platform, GpuThreadsRange diff --git a/src/pystencils/backend/memory.py b/src/pystencils/backend/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..b0333bd6677ac163a23aaa9418985c9cb8c4fda5 --- /dev/null +++ b/src/pystencils/backend/memory.py @@ -0,0 +1,207 @@ +from __future__ import annotations +from typing import ClassVar, Sequence +from itertools import chain +from dataclasses import dataclass + +from ..types import PsType, PsTypeError, deconstify, PsIntegerType +from .exceptions import PsInternalCompilerError +from ..field import Field +from .constants import PsConstant + + +@dataclass(frozen=True) +class PsSymbolProperty: + """Base class for symbol properties, which can be used to add additional information to symbols""" + + _unique: ClassVar[bool] = False + """Set to `True` in a subclass for property types of which only one instance per symbol is allowed.""" + + +@dataclass(frozen=True) +class FieldShape(PsSymbolProperty): + """Symbol acts as a shape parameter to a field.""" + + field: Field + coordinate: int + + +@dataclass(frozen=True) +class FieldStride(PsSymbolProperty): + """Symbol acts as a stride parameter to a field.""" + + field: Field + coordinate: int + + +@dataclass(frozen=True) +class FieldBasePtr(PsSymbolProperty): + """Symbol acts as a base pointer to a field.""" + + field: Field + + _unique: ClassVar[bool] = True + + +class PsSymbol: + """A mutable symbol with name and data type. + + Do not create objects of this class directly unless you know what you are doing; + instead obtain them from a `KernelCreationContext` through `KernelCreationContext.get_symbol`. + This way, the context can keep track of all symbols used in the translation run, + and uniqueness of symbols is ensured. + """ + + __match_args__ = ("name", "dtype") + + def __init__(self, name: str, dtype: PsType | None = None): + self._name = name + self._dtype = dtype + self._properties: set[PsSymbolProperty] = set() + + @property + def name(self) -> str: + return self._name + + @property + def dtype(self) -> PsType | None: + return self._dtype + + @dtype.setter + def dtype(self, value: PsType): + self._dtype = value + + def apply_dtype(self, dtype: PsType): + """Apply the given data type to this symbol, + raising a TypeError if it conflicts with a previously set data type.""" + + if self._dtype is not None and self._dtype != dtype: + raise PsTypeError( + f"Incompatible symbol data types: {self._dtype} and {dtype}" + ) + + self._dtype = dtype + + def get_dtype(self) -> PsType: + if self._dtype is None: + raise PsInternalCompilerError( + f"Symbol {self.name} had no type assigned yet" + ) + return self._dtype + + @property + def properties(self) -> frozenset[PsSymbolProperty]: + """Set of properties attached to this symbol""" + return frozenset(self._properties) + + def get_properties( + self, prop_type: type[PsSymbolProperty] + ) -> set[PsSymbolProperty]: + """Retrieve all properties of the given type attached to this symbol""" + return set(filter(lambda p: isinstance(p, prop_type), self._properties)) + + def add_property(self, property: PsSymbolProperty): + """Attach a property to this symbol""" + if property._unique and not self.get_properties(type(property)) <= {property}: + raise ValueError( + f"Cannot add second instance of unique property {type(property)} to symbol {self._name}." + ) + + self._properties.add(property) + + def remove_property(self, property: PsSymbolProperty): + """Remove a property from this symbol. Does nothing if the property is not attached.""" + self._properties.discard(property) + + def __str__(self) -> str: + dtype_str = "<untyped>" if self._dtype is None else str(self._dtype) + return f"{self._name}: {dtype_str}" + + def __repr__(self) -> str: + return f"PsSymbol({self._name}, {self._dtype})" + + +@dataclass(frozen=True) +class BufferBasePtr(PsSymbolProperty): + """Symbol acts as a base pointer to a field.""" + + buffer: PsBuffer + + _unique: ClassVar[bool] = True + + +class PsBuffer: + """N-dimensional contiguous linearized buffer in heap memory. + + `PsBuffer` models the memory buffers underlying the `Field` class + to the backend. Each buffer represents a contiguous block of memory + that is non-aliased and disjoint from all other buffers. + + Buffer shape and stride information are given either as constants or as symbols. + All indexing expressions must have the same data type, which will be selected as the buffer's + `index_dtype`. + + Each buffer has at least one base pointer, which can be retrieved via the `base_pointer` + property. + """ + + def __init__( + self, + name: str, + element_type: PsType, + base_ptr: PsSymbol, + shape: Sequence[PsSymbol | PsConstant], + strides: Sequence[PsSymbol | PsConstant], + ): + if len(shape) != len(strides): + raise ValueError("Buffer shape and stride tuples must have the same length") + + idx_types: set[PsType] = set( + deconstify(s.get_dtype()) for s in chain(shape, strides) + ) + if len(idx_types) > 1: + raise ValueError( + f"Conflicting data types in indexing symbols to buffer {name}: {idx_types}" + ) + + idx_dtype = idx_types.pop() + if not isinstance(idx_dtype, PsIntegerType): + raise ValueError( + f"Invalid index data type for buffer {name}: {idx_dtype}. Must be an integer type." + ) + + self._name = name + self._element_type = element_type + self._index_dtype = idx_dtype + + self._shape = tuple(shape) + self._strides = tuple(strides) + + base_ptr.add_property(BufferBasePtr(self)) + self._base_ptr = base_ptr + + @property + def name(self): + return self._name + + @property + def base_pointer(self) -> PsSymbol: + return self._base_ptr + + @property + def shape(self) -> tuple[PsSymbol | PsConstant, ...]: + return self._shape + + @property + def strides(self) -> tuple[PsSymbol | PsConstant, ...]: + return self._strides + + @property + def index_type(self) -> PsIntegerType: + return self._index_dtype + + @property + def element_type(self) -> PsType: + return self._element_type + + def __repr__(self) -> str: + return f"PsBuffer({self._name}: {self.element_type}[{len(self.shape)}D])" diff --git a/src/pystencils/backend/symbols.py b/src/pystencils/backend/symbols.py deleted file mode 100644 index b007e3fcf4791e89bb34829f0c6e4b7d1dbbbd21..0000000000000000000000000000000000000000 --- a/src/pystencils/backend/symbols.py +++ /dev/null @@ -1,55 +0,0 @@ -from ..types import PsType, PsTypeError -from .exceptions import PsInternalCompilerError - - -class PsSymbol: - """A mutable symbol with name and data type. - - Do not create objects of this class directly unless you know what you are doing; - instead obtain them from a `KernelCreationContext` through `KernelCreationContext.get_symbol`. - This way, the context can keep track of all symbols used in the translation run, - and uniqueness of symbols is ensured. - """ - - __match_args__ = ("name", "dtype") - - def __init__(self, name: str, dtype: PsType | None = None): - self._name = name - self._dtype = dtype - - @property - def name(self) -> str: - return self._name - - @property - def dtype(self) -> PsType | None: - return self._dtype - - @dtype.setter - def dtype(self, value: PsType): - self._dtype = value - - def apply_dtype(self, dtype: PsType): - """Apply the given data type to this symbol, - raising a TypeError if it conflicts with a previously set data type.""" - - if self._dtype is not None and self._dtype != dtype: - raise PsTypeError( - f"Incompatible symbol data types: {self._dtype} and {dtype}" - ) - - self._dtype = dtype - - def get_dtype(self) -> PsType: - if self._dtype is None: - raise PsInternalCompilerError( - f"Symbol {self.name} had no type assigned yet" - ) - return self._dtype - - def __str__(self) -> str: - dtype_str = "<untyped>" if self._dtype is None else str(self._dtype) - return f"{self._name}: {dtype_str}" - - def __repr__(self) -> str: - return f"PsSymbol({self._name}, {self._dtype})" diff --git a/src/pystencils/backend/transformations/canonical_clone.py b/src/pystencils/backend/transformations/canonical_clone.py index b21fd115f98645ff4c8dfb2dd3f72c252282fcf2..2cf9bcf0c8d95f16e9935ed1754f853f13516cc8 100644 --- a/src/pystencils/backend/transformations/canonical_clone.py +++ b/src/pystencils/backend/transformations/canonical_clone.py @@ -1,7 +1,7 @@ from typing import TypeVar, cast from ..kernelcreation import KernelCreationContext -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..exceptions import PsInternalCompilerError from ..ast import PsAstNode diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py index e55807ef4bd7726d173b6028997778538b66053b..f5b356432a56cc8c2a33eba6ad533947b9f2b2ad 100644 --- a/src/pystencils/backend/transformations/canonicalize_symbols.py +++ b/src/pystencils/backend/transformations/canonicalize_symbols.py @@ -1,5 +1,5 @@ from ..kernelcreation import KernelCreationContext -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..exceptions import PsInternalCompilerError from ..ast import PsAstNode diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index bd3b2bb5871b59dd1e0431ed3d5ee7d26c4ff7a6..222f4a378c3063cf58e1d14f714a5ccbdc524964 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -34,7 +34,7 @@ from ..ast.expressions import ( from ..ast.util import AstEqWrapper from ..constants import PsConstant -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..functions import PsMathFunction from ...types import ( PsIntegerType, diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py index 7404abd94771c7978d2b66124ffe3f3d64114319..08fd6bfa59bec56baa8a8206ee8bd3cd1d3882af 100644 --- a/src/pystencils/backend/transformations/erase_anonymous_structs.py +++ b/src/pystencils/backend/transformations/erase_anonymous_structs.py @@ -13,7 +13,6 @@ from ..ast.expressions import ( PsCast, ) from ..kernelcreation import Typifier -from ..arrays import PsArrayBasePointer, TypeErasedBasePointer from ...types import PsStructType, PsPointerType diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index d4dfd3d0494967273a4010e39385ceaec2ee29be..c10a696ab7ae5435b30f3ecf5c707df112d8956a 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -19,7 +19,7 @@ from ..ast.expressions import ( from ..ast.util import determine_memory_object from ...types import PsDereferencableType -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..functions import PsMathFunction __all__ = ["HoistLoopInvariantDeclarations"] diff --git a/tests/nbackend/kernelcreation/test_context.py b/tests/nbackend/kernelcreation/test_context.py index 9701013b000c26446649c82e6c96c5a64861f76c..ff766e6b5b47a0cd54a337b95a4c1d7feb4a9243 100644 --- a/tests/nbackend/kernelcreation/test_context.py +++ b/tests/nbackend/kernelcreation/test_context.py @@ -5,6 +5,7 @@ from pystencils import Field, TypedSymbol, FieldType, DynamicType from pystencils.backend.kernelcreation import KernelCreationContext from pystencils.backend.constants import PsConstant +from pystencils.backend.memory import PsSymbol, FieldShape, FieldStride from pystencils.backend.exceptions import KernelConstraintsError from pystencils.types.quick import SInt, Fp from pystencils.types import deconstify @@ -14,7 +15,7 @@ def test_field_arrays(): ctx = KernelCreationContext(index_dtype=SInt(16)) f = Field.create_generic("f", 3, Fp(32)) - f_arr = ctx.get_array(f) + f_arr = ctx.get_buffer(f) assert f_arr.element_type == f.dtype == Fp(32) assert len(f_arr.shape) == len(f.shape) + 1 == 4 @@ -23,9 +24,17 @@ def test_field_arrays(): assert f_arr.index_type == ctx.index_dtype == SInt(16) assert f_arr.shape[0].dtype == ctx.index_dtype == SInt(16) + for i, s in enumerate(f_arr.shape[:1]): + assert isinstance(s, PsSymbol) + assert FieldShape(f, i) in s.properties + + for i, s in enumerate(f_arr.strides[:1]): + assert isinstance(s, PsSymbol) + assert FieldStride(f, i) in s.properties + g = Field.create_generic("g", 3, index_shape=(2, 4), dtype=Fp(16)) - g_arr = ctx.get_array(g) - + g_arr = ctx.get_buffer(g) + assert g_arr.element_type == g.dtype == Fp(16) assert len(g_arr.shape) == len(g.spatial_shape) + len(g.index_shape) == 5 assert isinstance(g_arr.shape[3], PsConstant) and g_arr.shape[3].value == 2 @@ -39,26 +48,23 @@ def test_field_arrays(): FieldType.GENERIC, Fp(32), (0, 1), - ( - TypedSymbol("nx", SInt(32)), - TypedSymbol("ny", SInt(32)), - 1 - ), - ( - TypedSymbol("sx", SInt(32)), - TypedSymbol("sy", SInt(32)), - 1 - ) - ) - - h_arr = ctx.get_array(h) + (TypedSymbol("nx", SInt(32)), TypedSymbol("ny", SInt(32)), 1), + (TypedSymbol("sx", SInt(32)), TypedSymbol("sy", SInt(32)), 1), + ) + + h_arr = ctx.get_buffer(h) assert h_arr.index_type == SInt(32) - + for s in chain(h_arr.shape, h_arr.strides): assert deconstify(s.get_dtype()) == SInt(32) - assert [s.name for s in chain(h_arr.shape[:2], h_arr.strides[:2])] == ["nx", "ny", "sx", "sy"] + assert [s.name for s in chain(h_arr.shape[:2], h_arr.strides[:2])] == [ + "nx", + "ny", + "sx", + "sy", + ] def test_invalid_fields(): @@ -70,11 +76,11 @@ def test_invalid_fields(): Fp(32), (0,), (TypedSymbol("nx", SInt(32)),), - (TypedSymbol("sx", SInt(64)),) + (TypedSymbol("sx", SInt(64)),), ) - + with pytest.raises(KernelConstraintsError): - _ = ctx.get_array(h) + _ = ctx.get_buffer(h) h = Field( "h", @@ -82,11 +88,11 @@ def test_invalid_fields(): Fp(32), (0,), (TypedSymbol("nx", Fp(32)),), - (TypedSymbol("sx", Fp(32)),) + (TypedSymbol("sx", Fp(32)),), ) - + with pytest.raises(KernelConstraintsError): - _ = ctx.get_array(h) + _ = ctx.get_buffer(h) h = Field( "h", @@ -94,8 +100,39 @@ def test_invalid_fields(): Fp(32), (0,), (TypedSymbol("nx", DynamicType.NUMERIC_TYPE),), - (TypedSymbol("sx", DynamicType.NUMERIC_TYPE),) + (TypedSymbol("sx", DynamicType.NUMERIC_TYPE),), ) - + with pytest.raises(KernelConstraintsError): - _ = ctx.get_array(h) + _ = ctx.get_buffer(h) + + +def test_duplicate_fields(): + f = Field.create_generic("f", 3) + g = f.new_field_with_different_name("g") + + # f and g have the same indexing symbols + assert f.shape == g.shape + assert f.strides == g.strides + + ctx = KernelCreationContext() + + f_buf = ctx.get_buffer(f) + g_buf = ctx.get_buffer(g) + + for sf, sg in zip(chain(f_buf.shape, f_buf.strides), chain(g_buf.shape, g_buf.strides)): + # Must be the same + assert sf == sg + + for i, s in enumerate(f_buf.shape[:-1]): + assert isinstance(s, PsSymbol) + assert FieldShape(f, i) in s.properties + assert FieldShape(g, i) in s.properties + + for i, s in enumerate(f_buf.strides[:-1]): + assert isinstance(s, PsSymbol) + assert FieldStride(f, i) in s.properties + assert FieldStride(g, i) in s.properties + + # Base pointers must be different, though! + assert f_buf.base_pointer != g_buf.base_pointer diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 341b7560195a4f35cf88db056438fb987bbce8e6..65bf57e787c8cecdc6f1d279a5ca09fc374a07e4 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -106,8 +106,8 @@ def test_freeze_fields(): f, g = fields("f, g : [1D]") asm = Assignment(f.center(0), g.center(0)) - f_arr = ctx.get_array(f) - g_arr = ctx.get_array(g) + f_arr = ctx.get_buffer(f) + g_arr = ctx.get_buffer(g) fasm = freeze(asm) diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index 8ff678fbad398af18b71f2f30145ff34ab269685..5d56abd2b818fa74fbd48aac0216d472112f8c64 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -23,7 +23,7 @@ def test_slices_over_field(): islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, 1)) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) dims = ispace.dimensions @@ -58,7 +58,7 @@ def test_slices_with_fixed_size_field(): islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, 1)) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) dims = ispace.dimensions @@ -87,7 +87,7 @@ def test_singular_slice_over_field(): archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx") ctx.add_field(archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) islice = (4, -3) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) @@ -113,7 +113,7 @@ def test_slices_with_negative_start(): archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx") ctx.add_field(archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) islice = (slice(-3, -1, 1), slice(-4, None, 1)) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py index 02f03bfa9fe62b3a184ff986f82037d53e866b45..88fdd3c8d1e85bc8365bed7268882632aa904891 100644 --- a/tests/nbackend/test_ast.py +++ b/tests/nbackend/test_ast.py @@ -1,4 +1,4 @@ -from pystencils.backend.symbols import PsSymbol +from pystencils.backend.memory import PsSymbol from pystencils.backend.constants import PsConstant from pystencils.backend.ast.expressions import ( PsExpression, diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index dc7a86b0bdff8d73b58f260e9f6d3e03ddba3fd1..ef4806314eb52c7389bf583027bb808c42049213 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -3,10 +3,9 @@ from pystencils import Target from pystencils.backend.ast.expressions import PsExpression from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock from pystencils.backend.kernelfunction import KernelFunction -from pystencils.backend.symbols import PsSymbol +from pystencils.backend.memory import PsSymbol, PsBuffer from pystencils.backend.constants import PsConstant from pystencils.backend.literals import PsLiteral -from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer from pystencils.types.quick import Fp, SInt, UInt, Bool from pystencils.backend.emission import CAstPrinter diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py index b621829ad7e72383ec6651015b7813e0a009839b..dc321848645e65e62c605b6110951534e678492b 100644 --- a/tests/nbackend/test_cpujit.py +++ b/tests/nbackend/test_cpujit.py @@ -3,9 +3,8 @@ import pytest from pystencils import Target # from pystencils.backend.constraints import PsKernelParamsConstraint -from pystencils.backend.symbols import PsSymbol +from pystencils.backend.memory import PsSymbol, PsBuffer from pystencils.backend.constants import PsConstant -from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer from pystencils.backend.ast.expressions import PsArrayAccess, PsExpression from pystencils.backend.ast.structural import PsAssignment, PsBlock, PsLoop @@ -21,8 +20,8 @@ import numpy as np def test_pairwise_addition(): idx_type = SInt(64) - u = PsLinearizedArray("u", Fp(64, const=True), (...,), (...,), index_dtype=idx_type) - v = PsLinearizedArray("v", Fp(64), (...,), (...,), index_dtype=idx_type) + u = PsBuffer("u", Fp(64, const=True), (...,), (...,), index_dtype=idx_type) + v = PsBuffer("v", Fp(64), (...,), (...,), index_dtype=idx_type) u_data = PsArrayBasePointer("u_data", u) v_data = PsArrayBasePointer("v_data", v) diff --git a/tests/nbackend/test_memory.py b/tests/nbackend/test_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2ab340e5c5e8c4a358d4f9ee75a8d73b297c14 --- /dev/null +++ b/tests/nbackend/test_memory.py @@ -0,0 +1,52 @@ +import pytest + +from typing import ClassVar +from dataclasses import dataclass +from pystencils.backend.memory import PsSymbol, PsSymbolProperty + + +def test_properties(): + @dataclass(frozen=True) + class NumbersProperty(PsSymbolProperty): + n: int + x: float + + @dataclass(frozen=True) + class StringProperty(PsSymbolProperty): + s: str + + @dataclass(frozen=True) + class UniqueProperty(PsSymbolProperty): + val: int + _unique: ClassVar[bool] = True + + s = PsSymbol("s") + + assert not s.properties + + s.add_property(NumbersProperty(42, 8.71)) + assert s.properties == {NumbersProperty(42, 8.71)} + + # no duplicates + s.add_property(NumbersProperty(42, 8.71)) + assert s.properties == {NumbersProperty(42, 8.71)} + + s.add_property(StringProperty("pystencils")) + assert s.properties == {NumbersProperty(42, 8.71), StringProperty("pystencils")} + + assert s.get_properties(NumbersProperty) == {NumbersProperty(42, 8.71)} + + assert not s.get_properties(UniqueProperty) + + s.add_property(UniqueProperty(13)) + assert s.get_properties(UniqueProperty) == {UniqueProperty(13)} + + # Adding the same one again does not raise + s.add_property(UniqueProperty(13)) + assert s.get_properties(UniqueProperty) == {UniqueProperty(13)} + + with pytest.raises(ValueError): + s.add_property(UniqueProperty(14)) + + s.remove_property(UniqueProperty(13)) + assert not s.get_properties(UniqueProperty) diff --git a/tests/nbackend/transformations/test_canonicalize_symbols.py b/tests/nbackend/transformations/test_canonicalize_symbols.py index 6a478556469b6409507ae7b4ccc545311ab38462..a11e9bd1353ef98c5ae23e0d86ddce3e5c9d579c 100644 --- a/tests/nbackend/transformations/test_canonicalize_symbols.py +++ b/tests/nbackend/transformations/test_canonicalize_symbols.py @@ -52,7 +52,7 @@ def test_deduplication(): assert canonicalize.get_last_live_symbols() == { ctx.find_symbol("y"), ctx.find_symbol("z"), - ctx.get_array(f).base_pointer, + ctx.get_buffer(f).base_pointer, } assert ctx.find_symbol("x") is not None diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index 92bb5c947b4bc2e4b6d50064a9c07874cdda43cf..4c18970086b4537dec2ae974ddf6242da57b591e 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -1,6 +1,6 @@ from pystencils.backend.kernelcreation import KernelCreationContext, Typifier from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr -from pystencils.backend.symbols import PsSymbol +from pystencils.backend.memory import PsSymbol from pystencils.backend.constants import PsConstant from pystencils.backend.transformations import EliminateConstants