diff --git a/pystencils/nbackend/arrays.py b/pystencils/nbackend/arrays.py new file mode 100644 index 0000000000000000000000000000000000000000..23f0ff7788e9464c8ed3860e78993b3ff4527d0f --- /dev/null +++ b/pystencils/nbackend/arrays.py @@ -0,0 +1,209 @@ +""" +Arrays +====== + +The pystencils backend models contiguous n-dimensional arrays using a number of classes. +Arrays themselves are represented through the `PsLinearizedArray` class. +An array has a fixed name, dimensionality, and element type, as well as a number of associated +variables. + +The associated variables are the *shape* and *strides* of the array, modelled by the +`PsArrayShapeVar` and `PsArrayStrideVar` classes. They have integer type and are used to +reason about the array's memory layout. + + +Memory Layout Constraints +------------------------- + +Initially, all memory layout information about an array is symbolic and unconstrained. +Several scenarios exist where memory layout must be constrained, e.g. certain pointers +need to be aligned, certain strides must be fixed or fulfill certain alignment properties, +or even the field shape must be fixed. + +The code generation backend models such requirements and assumptions as *constraints*. +Constraints are external to the arrays themselves. They are created by the AST passes which +require them and exposed through the `PsKernelFunction` class to the compiler kernel's runtime +environment. It is the responsibility of the runtime environment to fulfill all constraints. + +For example, if an array `arr` should have both a fixed shape and fixed strides, +an optimization pass will have to add equality constraints like the following before replacing +all occurences of the shape and stride variables with their constant value: + +``` +constraints = ( + [PsParamConstraint(s.eq(f)) for s, f in zip(arr.shape, fixed_size)] + + [PsParamConstraint(s.eq(f)) for s, f in zip(arr.strides, fixed_strides)] +) + +kernel_function.add_constraints(*constraints) +``` + +""" + + +from __future__ import annotations + +from typing import TYPE_CHECKING +from abc import ABC + +import pymbolic.primitives as pb + +from .types import ( + PsAbstractType, + PsScalarType, + PsPointerType, + PsIntegerType, + PsSignedIntegerType, + constify, +) + +if TYPE_CHECKING: + from .typed_expressions import PsTypedVariable, PsTypedConstant + + +class PsLinearizedArray: + """N-dimensional contiguous array""" + + def __init__( + self, + name: str, + element_type: PsScalarType, + dim: int, + offsets: tuple[int, ...] | None = None, + index_dtype: PsIntegerType = PsSignedIntegerType(64), + ): + self._name = name + + if offsets is not None and len(offsets) != dim: + raise ValueError(f"Must have exactly {dim} offsets.") + + self._shape = tuple( + PsArrayShapeVar(self, d, constify(index_dtype)) for d in range(dim) + ) + self._strides = tuple( + PsArrayStrideVar(self, d, constify(index_dtype)) for d in range(dim) + ) + self._element_type = element_type + + if offsets is None: + offsets = (0,) * dim + + self._offsets = tuple(PsTypedConstant(o, index_dtype) for o in offsets) + + @property + def name(self): + return self._name + + @property + def shape(self): + return self._shape + + @property + def strides(self): + return self._strides + + @property + def element_type(self): + return self._element_type + + @property + def offsets(self) -> tuple[PsTypedConstant, ...]: + return self._offsets + + +class PsArrayAssocVar(PsTypedVariable, ABC): + """A variable that is associated to an array. + + Instances of this class represent pointers and indexing information bound + to a particular array. + """ + + def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray): + super().__init__(name, dtype) + self._array = array + + @property + def array(self) -> PsLinearizedArray: + return self._array + + +class PsArrayBasePointer(PsArrayAssocVar): + def __init__(self, name: str, array: PsLinearizedArray): + dtype = PsPointerType(array.element_type) + super().__init__(name, dtype, array) + + self._array = array + + +class PsArrayShapeVar(PsArrayAssocVar): + def __init__(self, array: PsLinearizedArray, dimension: int, dtype: PsIntegerType): + name = f"{array}_size{dimension}" + super().__init__(name, dtype, array) + + +class PsArrayStrideVar(PsArrayAssocVar): + def __init__(self, array: PsLinearizedArray, dimension: int, dtype: PsIntegerType): + name = f"{array}_size{dimension}" + super().__init__(name, dtype, array) + + +class PsArrayAccess(pb.Subscript): + def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression): + super(PsArrayAccess, self).__init__(base_ptr, index) + self._base_ptr = base_ptr + self._index = index + + @property + def base_ptr(self): + return self._base_ptr + + @property + def array(self) -> PsLinearizedArray: + return self._base_ptr.array + + @property + def dtype(self) -> PsAbstractType: + """Data type of this expression, i.e. the element type of the underlying array""" + return self._base_ptr.array.element_type + + +# class PsIterationDomain: +# """A factory for arrays spanning a given iteration domain.""" + +# def __init__( +# self, +# id: str, +# dim: int | None = None, +# fixed_shape: tuple[int, ...] | None = None, +# index_dtype: PsIntegerType = PsSignedIntegerType(64), +# ): +# if fixed_shape is not None: +# if dim is not None and len(fixed_shape) != dim: +# raise ValueError( +# "If both `dim` and `fixed_shape` are specified, `fixed_shape` must have exactly `dim` entries." +# ) + +# shape = tuple(PsTypedConstant(s, index_dtype) for s in fixed_shape) +# elif dim is not None: +# shape = tuple( +# PsTypedVariable(f"{id}_shape_{d}", index_dtype) for d in range(dim) +# ) +# else: +# raise ValueError("Either `fixed_shape` or `dim` must be specified.") + +# self._domain_shape: tuple[VarOrConstant, ...] = shape +# self._index_dtype = index_dtype + +# self._archetype_array: PsLinearizedArray | None = None + +# self._constraints: list[PsParamConstraint] = [] + +# @property +# def dim(self) -> int: +# return len(self._domain_shape) + +# @property +# def shape(self) -> tuple[VarOrConstant, ...]: +# return self._domain_shape + +# def create_array(self, ghost_layers: int = 0): diff --git a/pystencils/nbackend/ast/constraints.py b/pystencils/nbackend/ast/constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..68cbe347a3acc7c5716da86277f952323e26f8ae --- /dev/null +++ b/pystencils/nbackend/ast/constraints.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + +import pymbolic.primitives as pb +from pymbolic.mapper.c_code import CCodeMapper + + +@dataclass +class PsParamConstraint: + condition: pb.Comparison + message: str = "" + + def print(self): + return CCodeMapper()(self.condition) diff --git a/pystencils/nbackend/ast/kernelfunction.py b/pystencils/nbackend/ast/kernelfunction.py index aaf1ac5e51c30277d1b0f64ca7ccfcde28145339..fccd4f12a18292a53566ad3a7d93eef6fef637f4 100644 --- a/pystencils/nbackend/ast/kernelfunction.py +++ b/pystencils/nbackend/ast/kernelfunction.py @@ -1,13 +1,68 @@ -from typing import Sequence +from __future__ import annotations from typing import Generator +from dataclasses import dataclass + +from pymbolic.mapper.dependency import DependencyMapper + from .nodes import PsAstNode, PsBlock, failing_cast +from .constraints import PsParamConstraint from ..typed_expressions import PsTypedVariable +from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar +from ..exceptions import PsInternalCompilerError from ...enums import Target +@dataclass +class PsKernelParametersSpec: + """Specification of a kernel function's parameters. + + Contains: + - Verbatim parameter list, a list of `PsTypedVariables` + - List of Arrays used in the kernel, in canonical order + - A set of constraints on the kernel parameters, used to e.g. express relations of array + shapes, alignment properties, ... + """ + + params: tuple[PsTypedVariable, ...] + arrays: tuple[PsLinearizedArray, ...] + constraints: tuple[PsParamConstraint, ...] + + def params_for_array(self, arr: PsLinearizedArray): + def pred(p: PsTypedVariable): + return isinstance(p, PsArrayAssocVar) and p.array == arr + + return tuple(filter(pred, self.params)) + + def __post_init__(self): + dep_mapper = DependencyMapper(False, False, False, False) + + # Check constraints + for constraint in self.constraints: + variables: set[PsTypedVariable] = dep_mapper(constraint.condition) + for var in variables: + if isinstance(var, PsArrayAssocVar): + if var.array in self.arrays: + continue + + elif var in self.params: + continue + + else: + raise PsInternalCompilerError( + "Constrained parameter was neither contained in kernel parameter list " + "nor associated with a kernel array.\n" + f" Parameter: {var}\n" + f" Constraint: {constraint.condition}" + ) + + class PsKernelFunction(PsAstNode): - """A complete pystencils kernel function.""" + """A pystencils kernel function. + + Objects of this class represent a full pystencils kernel and should provide all information required for + export, compilation, and inclusion of the kernel into a runtime system. + """ __match_args__ = ("body",) @@ -16,6 +71,8 @@ class PsKernelFunction(PsAstNode): self._target = target self._name = name + self._constraints: list[PsParamConstraint] = [] + @property def target(self) -> Target: """See pystencils.Target""" @@ -53,7 +110,10 @@ class PsKernelFunction(PsAstNode): raise IndexError(f"Child index out of bounds: {idx}") self._body = failing_cast(PsBlock, c) - def get_parameters(self) -> Sequence[PsTypedVariable]: + def add_constraints(self, *constraints: PsParamConstraint): + self._constraints += constraints + + def get_parameters(self) -> PsKernelParametersSpec: """Collect the list of parameters to this function. This function performs a full traversal of the AST. @@ -61,5 +121,8 @@ class PsKernelFunction(PsAstNode): """ from .analysis import UndefinedVariablesCollector - params = UndefinedVariablesCollector().collect(self) - return sorted(params, key=lambda p: p.name) + params_set = UndefinedVariablesCollector().collect(self) + params_list = sorted(params_set, key=lambda p: p.name) + + arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer)) + return PsKernelParametersSpec(tuple(params_list), tuple(arrays), tuple(self._constraints)) diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py index da3ead273ce14a80c0874630aefebf5cf2835d80..2944e073dd97549fd7087013eb84b911640010c8 100644 --- a/pystencils/nbackend/ast/nodes.py +++ b/pystencils/nbackend/ast/nodes.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Sequence, Generator, Iterable, cast +from typing import Sequence, Generator, Iterable, cast, TypeAlias from abc import ABC, abstractmethod -from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue, ExprOrConstant +from ..typed_expressions import PsTypedVariable, ExprOrConstant +from ..arrays import PsArrayAccess from .util import failing_cast @@ -123,6 +124,10 @@ class PsSymbolExpr(PsLvalueExpr): self._expr = symbol +PsLvalue: TypeAlias = PsTypedVariable | PsArrayAccess +"""Types of expressions that may occur on the left-hand side of assignments.""" + + class PsAssignment(PsAstNode): __match_args__ = ( "lhs", diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/c_printer.py index 6872ad9dfcf2fb9981bdaa4f3aca901c4007c63b..58a61d579991f6fa2c51ee2300b7941bf584ed0e 100644 --- a/pystencils/nbackend/c_printer.py +++ b/pystencils/nbackend/c_printer.py @@ -27,8 +27,8 @@ class CPrinter: @visit.case(PsKernelFunction) def function(self, func: PsKernelFunction) -> str: - params = func.get_parameters() - params_str = ", ".join(f"{p.dtype} {p.name}" for p in params) + params_spec = func.get_parameters() + params_str = ", ".join(f"{p.dtype} {p.name}" for p in params_spec.params) decl = f"FUNC_PREFIX void {func.name} ({params_str})" body = self.visit(func.body) return f"{decl}\n{body}" diff --git a/pystencils/nbackend/sympy_mapper.py b/pystencils/nbackend/sympy_mapper.py index 380ed699d285141feb7028a05d2e1ddce1059fae..fecad4c63bb316583b98bc019407e9cfc107b206 100644 --- a/pystencils/nbackend/sympy_mapper.py +++ b/pystencils/nbackend/sympy_mapper.py @@ -4,7 +4,8 @@ from pystencils.typing import TypedSymbol from pystencils.typing.typed_sympy import SHAPE_DTYPE from .ast.nodes import PsAssignment, PsSymbolExpr from .types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType -from .typed_expressions import PsArrayBasePointer, PsLinearizedArray, PsTypedVariable, PsArrayAccess +from .typed_expressions import PsTypedVariable +from .arrays import PsArrayBasePointer, PsLinearizedArray, PsArrayAccess CTR_SYMBOLS = [TypedSymbol(f"ctr_{i}", SHAPE_DTYPE) for i in range(3)] @@ -44,7 +45,9 @@ class PystencilsToPymbolicMapper(SympyToPymbolicMapper): array = PsLinearizedArray(name, shape, strides, dtype) ptr = PsArrayBasePointer(expr.name, array) - index = sum([ctr * stride for ctr, stride in zip(CTR_SYMBOLS, expr.field.strides)]) + index = sum( + [ctr * stride for ctr, stride in zip(CTR_SYMBOLS, expr.field.strides)] + ) index = self.rec(index) return PsSymbolExpr(PsArrayAccess(ptr, index)) diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py index ee3536a52e3b2ffc3bf2908c3f3254348020e94d..e878fbfcee0b03f4444382672a624fad10f5679e 100644 --- a/pystencils/nbackend/typed_expressions.py +++ b/pystencils/nbackend/typed_expressions.py @@ -1,15 +1,12 @@ from __future__ import annotations -from functools import reduce -from typing import TypeAlias, Union, Any, Tuple +from typing import TypeAlias, Any import pymbolic.primitives as pb from .types import ( PsAbstractType, - PsScalarType, PsNumericType, - PsPointerType, constify, PsTypeError, ) @@ -25,91 +22,6 @@ class PsTypedVariable(pb.Variable): return self._dtype -class PsArray: - def __init__( - self, - name: str, - length: pb.Expression, - element_type: PsScalarType, # todo Frederik: is PsScalarType correct? - ): - self._name = name - self._length = length - self._element_type = element_type - - @property - def name(self): - return self._name - - @property - def length(self): - return self._length - - @property - def element_type(self): - return self._element_type - - -class PsLinearizedArray(PsArray): - """N-dimensional contiguous array""" - - def __init__( - self, - name: str, - shape: Tuple[pb.Expression, ...], - strides: Tuple[pb.Expression], - element_type: PsScalarType, - ): - length = reduce(lambda x, y: x * y, shape) - super().__init__(name, length, element_type) - - self._shape = shape - self._strides = strides - - @property - def shape(self): - return self._shape - - @property - def strides(self): - return self._strides - - -class PsArrayBasePointer(PsTypedVariable): - def __init__(self, name: str, array: PsArray): - dtype = PsPointerType(array.element_type) - super().__init__(name, dtype) - - self._array = array - - @property - def array(self): - return self._array - - -class PsArrayAccess(pb.Subscript): - def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression): - super(PsArrayAccess, self).__init__(base_ptr, index) - self._base_ptr = base_ptr - self._index = index - - @property - def base_ptr(self): - return self._base_ptr - - # @property - # def index(self): - # return self._index - - @property - def array(self) -> PsArray: - return self._base_ptr.array - - @property - def dtype(self) -> PsAbstractType: - """Data type of this expression, i.e. the element type of the underlying array""" - return self._base_ptr.array.element_type - - class PsTypedConstant: """Represents typed constants occuring in the pystencils AST. @@ -290,9 +202,7 @@ class PsTypedConstant: pb.register_constant_class(PsTypedConstant) - -PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess] -"""Types of expressions that may occur on the left-hand side of assignments.""" - ExprOrConstant: TypeAlias = pb.Expression | PsTypedConstant """Required since `PsTypedConstant` does not derive from `pb.Expression`.""" + +VarOrConstant: TypeAlias = PsTypedVariable | PsTypedConstant