Skip to content
Snippets Groups Projects
Commit ec0f31e8 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

array refactoring, constraints

parent 20d37fba
No related branches found
No related tags found
No related merge requests found
"""
Arrays
======
The pystencils backend models contiguous n-dimensional arrays using a number of classes.
Arrays themselves are represented through the `PsLinearizedArray` class.
An array has a fixed name, dimensionality, and element type, as well as a number of associated
variables.
The associated variables are the *shape* and *strides* of the array, modelled by the
`PsArrayShapeVar` and `PsArrayStrideVar` classes. They have integer type and are used to
reason about the array's memory layout.
Memory Layout Constraints
-------------------------
Initially, all memory layout information about an array is symbolic and unconstrained.
Several scenarios exist where memory layout must be constrained, e.g. certain pointers
need to be aligned, certain strides must be fixed or fulfill certain alignment properties,
or even the field shape must be fixed.
The code generation backend models such requirements and assumptions as *constraints*.
Constraints are external to the arrays themselves. They are created by the AST passes which
require them and exposed through the `PsKernelFunction` class to the compiler kernel's runtime
environment. It is the responsibility of the runtime environment to fulfill all constraints.
For example, if an array `arr` should have both a fixed shape and fixed strides,
an optimization pass will have to add equality constraints like the following before replacing
all occurences of the shape and stride variables with their constant value:
```
constraints = (
[PsParamConstraint(s.eq(f)) for s, f in zip(arr.shape, fixed_size)]
+ [PsParamConstraint(s.eq(f)) for s, f in zip(arr.strides, fixed_strides)]
)
kernel_function.add_constraints(*constraints)
```
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from abc import ABC
import pymbolic.primitives as pb
from .types import (
PsAbstractType,
PsScalarType,
PsPointerType,
PsIntegerType,
PsSignedIntegerType,
constify,
)
if TYPE_CHECKING:
from .typed_expressions import PsTypedVariable, PsTypedConstant
class PsLinearizedArray:
"""N-dimensional contiguous array"""
def __init__(
self,
name: str,
element_type: PsScalarType,
dim: int,
offsets: tuple[int, ...] | None = None,
index_dtype: PsIntegerType = PsSignedIntegerType(64),
):
self._name = name
if offsets is not None and len(offsets) != dim:
raise ValueError(f"Must have exactly {dim} offsets.")
self._shape = tuple(
PsArrayShapeVar(self, d, constify(index_dtype)) for d in range(dim)
)
self._strides = tuple(
PsArrayStrideVar(self, d, constify(index_dtype)) for d in range(dim)
)
self._element_type = element_type
if offsets is None:
offsets = (0,) * dim
self._offsets = tuple(PsTypedConstant(o, index_dtype) for o in offsets)
@property
def name(self):
return self._name
@property
def shape(self):
return self._shape
@property
def strides(self):
return self._strides
@property
def element_type(self):
return self._element_type
@property
def offsets(self) -> tuple[PsTypedConstant, ...]:
return self._offsets
class PsArrayAssocVar(PsTypedVariable, ABC):
"""A variable that is associated to an array.
Instances of this class represent pointers and indexing information bound
to a particular array.
"""
def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray):
super().__init__(name, dtype)
self._array = array
@property
def array(self) -> PsLinearizedArray:
return self._array
class PsArrayBasePointer(PsArrayAssocVar):
def __init__(self, name: str, array: PsLinearizedArray):
dtype = PsPointerType(array.element_type)
super().__init__(name, dtype, array)
self._array = array
class PsArrayShapeVar(PsArrayAssocVar):
def __init__(self, array: PsLinearizedArray, dimension: int, dtype: PsIntegerType):
name = f"{array}_size{dimension}"
super().__init__(name, dtype, array)
class PsArrayStrideVar(PsArrayAssocVar):
def __init__(self, array: PsLinearizedArray, dimension: int, dtype: PsIntegerType):
name = f"{array}_size{dimension}"
super().__init__(name, dtype, array)
class PsArrayAccess(pb.Subscript):
def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression):
super(PsArrayAccess, self).__init__(base_ptr, index)
self._base_ptr = base_ptr
self._index = index
@property
def base_ptr(self):
return self._base_ptr
@property
def array(self) -> PsLinearizedArray:
return self._base_ptr.array
@property
def dtype(self) -> PsAbstractType:
"""Data type of this expression, i.e. the element type of the underlying array"""
return self._base_ptr.array.element_type
# class PsIterationDomain:
# """A factory for arrays spanning a given iteration domain."""
# def __init__(
# self,
# id: str,
# dim: int | None = None,
# fixed_shape: tuple[int, ...] | None = None,
# index_dtype: PsIntegerType = PsSignedIntegerType(64),
# ):
# if fixed_shape is not None:
# if dim is not None and len(fixed_shape) != dim:
# raise ValueError(
# "If both `dim` and `fixed_shape` are specified, `fixed_shape` must have exactly `dim` entries."
# )
# shape = tuple(PsTypedConstant(s, index_dtype) for s in fixed_shape)
# elif dim is not None:
# shape = tuple(
# PsTypedVariable(f"{id}_shape_{d}", index_dtype) for d in range(dim)
# )
# else:
# raise ValueError("Either `fixed_shape` or `dim` must be specified.")
# self._domain_shape: tuple[VarOrConstant, ...] = shape
# self._index_dtype = index_dtype
# self._archetype_array: PsLinearizedArray | None = None
# self._constraints: list[PsParamConstraint] = []
# @property
# def dim(self) -> int:
# return len(self._domain_shape)
# @property
# def shape(self) -> tuple[VarOrConstant, ...]:
# return self._domain_shape
# def create_array(self, ghost_layers: int = 0):
from dataclasses import dataclass
import pymbolic.primitives as pb
from pymbolic.mapper.c_code import CCodeMapper
@dataclass
class PsParamConstraint:
condition: pb.Comparison
message: str = ""
def print(self):
return CCodeMapper()(self.condition)
from typing import Sequence
from __future__ import annotations
from typing import Generator
from dataclasses import dataclass
from pymbolic.mapper.dependency import DependencyMapper
from .nodes import PsAstNode, PsBlock, failing_cast
from .constraints import PsParamConstraint
from ..typed_expressions import PsTypedVariable
from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar
from ..exceptions import PsInternalCompilerError
from ...enums import Target
@dataclass
class PsKernelParametersSpec:
"""Specification of a kernel function's parameters.
Contains:
- Verbatim parameter list, a list of `PsTypedVariables`
- List of Arrays used in the kernel, in canonical order
- A set of constraints on the kernel parameters, used to e.g. express relations of array
shapes, alignment properties, ...
"""
params: tuple[PsTypedVariable, ...]
arrays: tuple[PsLinearizedArray, ...]
constraints: tuple[PsParamConstraint, ...]
def params_for_array(self, arr: PsLinearizedArray):
def pred(p: PsTypedVariable):
return isinstance(p, PsArrayAssocVar) and p.array == arr
return tuple(filter(pred, self.params))
def __post_init__(self):
dep_mapper = DependencyMapper(False, False, False, False)
# Check constraints
for constraint in self.constraints:
variables: set[PsTypedVariable] = dep_mapper(constraint.condition)
for var in variables:
if isinstance(var, PsArrayAssocVar):
if var.array in self.arrays:
continue
elif var in self.params:
continue
else:
raise PsInternalCompilerError(
"Constrained parameter was neither contained in kernel parameter list "
"nor associated with a kernel array.\n"
f" Parameter: {var}\n"
f" Constraint: {constraint.condition}"
)
class PsKernelFunction(PsAstNode):
"""A complete pystencils kernel function."""
"""A pystencils kernel function.
Objects of this class represent a full pystencils kernel and should provide all information required for
export, compilation, and inclusion of the kernel into a runtime system.
"""
__match_args__ = ("body",)
......@@ -16,6 +71,8 @@ class PsKernelFunction(PsAstNode):
self._target = target
self._name = name
self._constraints: list[PsParamConstraint] = []
@property
def target(self) -> Target:
"""See pystencils.Target"""
......@@ -53,7 +110,10 @@ class PsKernelFunction(PsAstNode):
raise IndexError(f"Child index out of bounds: {idx}")
self._body = failing_cast(PsBlock, c)
def get_parameters(self) -> Sequence[PsTypedVariable]:
def add_constraints(self, *constraints: PsParamConstraint):
self._constraints += constraints
def get_parameters(self) -> PsKernelParametersSpec:
"""Collect the list of parameters to this function.
This function performs a full traversal of the AST.
......@@ -61,5 +121,8 @@ class PsKernelFunction(PsAstNode):
"""
from .analysis import UndefinedVariablesCollector
params = UndefinedVariablesCollector().collect(self)
return sorted(params, key=lambda p: p.name)
params_set = UndefinedVariablesCollector().collect(self)
params_list = sorted(params_set, key=lambda p: p.name)
arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer))
return PsKernelParametersSpec(tuple(params_list), tuple(arrays), tuple(self._constraints))
from __future__ import annotations
from typing import Sequence, Generator, Iterable, cast
from typing import Sequence, Generator, Iterable, cast, TypeAlias
from abc import ABC, abstractmethod
from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue, ExprOrConstant
from ..typed_expressions import PsTypedVariable, ExprOrConstant
from ..arrays import PsArrayAccess
from .util import failing_cast
......@@ -123,6 +124,10 @@ class PsSymbolExpr(PsLvalueExpr):
self._expr = symbol
PsLvalue: TypeAlias = PsTypedVariable | PsArrayAccess
"""Types of expressions that may occur on the left-hand side of assignments."""
class PsAssignment(PsAstNode):
__match_args__ = (
"lhs",
......
......@@ -27,8 +27,8 @@ class CPrinter:
@visit.case(PsKernelFunction)
def function(self, func: PsKernelFunction) -> str:
params = func.get_parameters()
params_str = ", ".join(f"{p.dtype} {p.name}" for p in params)
params_spec = func.get_parameters()
params_str = ", ".join(f"{p.dtype} {p.name}" for p in params_spec.params)
decl = f"FUNC_PREFIX void {func.name} ({params_str})"
body = self.visit(func.body)
return f"{decl}\n{body}"
......
......@@ -4,7 +4,8 @@ from pystencils.typing import TypedSymbol
from pystencils.typing.typed_sympy import SHAPE_DTYPE
from .ast.nodes import PsAssignment, PsSymbolExpr
from .types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType
from .typed_expressions import PsArrayBasePointer, PsLinearizedArray, PsTypedVariable, PsArrayAccess
from .typed_expressions import PsTypedVariable
from .arrays import PsArrayBasePointer, PsLinearizedArray, PsArrayAccess
CTR_SYMBOLS = [TypedSymbol(f"ctr_{i}", SHAPE_DTYPE) for i in range(3)]
......@@ -44,7 +45,9 @@ class PystencilsToPymbolicMapper(SympyToPymbolicMapper):
array = PsLinearizedArray(name, shape, strides, dtype)
ptr = PsArrayBasePointer(expr.name, array)
index = sum([ctr * stride for ctr, stride in zip(CTR_SYMBOLS, expr.field.strides)])
index = sum(
[ctr * stride for ctr, stride in zip(CTR_SYMBOLS, expr.field.strides)]
)
index = self.rec(index)
return PsSymbolExpr(PsArrayAccess(ptr, index))
from __future__ import annotations
from functools import reduce
from typing import TypeAlias, Union, Any, Tuple
from typing import TypeAlias, Any
import pymbolic.primitives as pb
from .types import (
PsAbstractType,
PsScalarType,
PsNumericType,
PsPointerType,
constify,
PsTypeError,
)
......@@ -25,91 +22,6 @@ class PsTypedVariable(pb.Variable):
return self._dtype
class PsArray:
def __init__(
self,
name: str,
length: pb.Expression,
element_type: PsScalarType, # todo Frederik: is PsScalarType correct?
):
self._name = name
self._length = length
self._element_type = element_type
@property
def name(self):
return self._name
@property
def length(self):
return self._length
@property
def element_type(self):
return self._element_type
class PsLinearizedArray(PsArray):
"""N-dimensional contiguous array"""
def __init__(
self,
name: str,
shape: Tuple[pb.Expression, ...],
strides: Tuple[pb.Expression],
element_type: PsScalarType,
):
length = reduce(lambda x, y: x * y, shape)
super().__init__(name, length, element_type)
self._shape = shape
self._strides = strides
@property
def shape(self):
return self._shape
@property
def strides(self):
return self._strides
class PsArrayBasePointer(PsTypedVariable):
def __init__(self, name: str, array: PsArray):
dtype = PsPointerType(array.element_type)
super().__init__(name, dtype)
self._array = array
@property
def array(self):
return self._array
class PsArrayAccess(pb.Subscript):
def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression):
super(PsArrayAccess, self).__init__(base_ptr, index)
self._base_ptr = base_ptr
self._index = index
@property
def base_ptr(self):
return self._base_ptr
# @property
# def index(self):
# return self._index
@property
def array(self) -> PsArray:
return self._base_ptr.array
@property
def dtype(self) -> PsAbstractType:
"""Data type of this expression, i.e. the element type of the underlying array"""
return self._base_ptr.array.element_type
class PsTypedConstant:
"""Represents typed constants occuring in the pystencils AST.
......@@ -290,9 +202,7 @@ class PsTypedConstant:
pb.register_constant_class(PsTypedConstant)
PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess]
"""Types of expressions that may occur on the left-hand side of assignments."""
ExprOrConstant: TypeAlias = pb.Expression | PsTypedConstant
"""Required since `PsTypedConstant` does not derive from `pb.Expression`."""
VarOrConstant: TypeAlias = PsTypedVariable | PsTypedConstant
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment