Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (12)
Showing
with 938 additions and 290 deletions
......@@ -14,6 +14,7 @@ who wish to customize or extend the behaviour of the code generator in their app
iteration_space
translation
platforms
transformations
jit
Internal Representation
......@@ -21,7 +22,7 @@ Internal Representation
The code generator translates the kernel from the SymPy frontend's symbolic language to an internal
representation (IR), which is then emitted as code in the required dialect of C.
All names of classes associated with the internal kernel representation are prefixed `Ps...`
All names of classes associated with the internal kernel representation are prefixed ``Ps...``
to distinguis them from identically named front-end and SymPy classes.
The IR comprises *symbols*, *constants*, *arrays*, the *iteration space* and the *abstract syntax tree*:
......
*******************
AST Transformations
*******************
`pystencils.backend.transformations`
.. automodule:: pystencils.backend.transformations
......@@ -5,6 +5,9 @@ Kernel Translation
.. autoclass:: pystencils.backend.kernelcreation.KernelCreationContext
:members:
.. autoclass:: pystencils.backend.kernelcreation.AstFactory
:members:
.. autoclass:: pystencils.backend.kernelcreation.KernelAnalysis
:members:
......
......@@ -18,14 +18,14 @@ from ..exceptions import PsInternalCompilerError
class UndefinedSymbolsCollector:
"""Collector for undefined variables.
"""Collect undefined symbols.
This class implements an AST visitor that collects all `PsTypedVariable`s that have been used
This class implements an AST visitor that collects all symbols that have been used
in the AST without being defined prior to their usage.
"""
def __call__(self, node: PsAstNode) -> set[PsSymbol]:
"""Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage."""
"""Returns all symbols that occur in the given AST without being defined prior to their usage."""
return self.visit(node)
def visit(self, node: PsAstNode) -> set[PsSymbol]:
......@@ -79,8 +79,8 @@ class UndefinedSymbolsCollector:
"""Returns the set of variables declared by the given node which are visible in the enclosing scope."""
match node:
case PsDeclaration(lhs, _):
return {lhs.symbol}
case PsDeclaration():
return {node.declared_symbol}
case (
PsAssignment()
......
......@@ -21,7 +21,7 @@ from .astnode import PsAstNode, PsLeafMixIn
class PsExpression(PsAstNode, ABC):
"""Base class for all expressions.
**Types:** Each expression should be annotated with its type.
Upon construction, the `dtype` property of most expression nodes is unset;
only constant expressions, symbol expressions, and array accesses immediately inherit their type from
......@@ -149,7 +149,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
return self._constant == other._constant
def __repr__(self) -> str:
return f"Constant({repr(self._constant)})"
return f"PsConstantExpr({repr(self._constant)})"
class PsSubscript(PsLvalue, PsExpression):
......@@ -271,7 +271,7 @@ class PsVectorArrayAccess(PsArrayAccess):
@property
def alignment(self) -> int:
return self._alignment
def get_vector_type(self) -> PsVectorType:
return cast(PsVectorType, self._dtype)
......@@ -385,6 +385,18 @@ class PsCall(PsExpression):
return super().structurally_equal(other) and self._function == other._function
class PsNumericOpTrait:
"""Trait for operations valid only on numerical types"""
class PsIntOpTrait:
"""Trait for operations valid only on integer types"""
class PsBoolOpTrait:
"""Trait for boolean operations"""
class PsUnOp(PsExpression):
__match_args__ = ("operand",)
......@@ -414,8 +426,12 @@ class PsUnOp(PsExpression):
def python_operator(self) -> None | Callable[[Any], Any]:
return None
def __repr__(self) -> str:
opname = self.__class__.__name__
return f"{opname}({repr(self._operand)})"
class PsNeg(PsUnOp):
class PsNeg(PsUnOp, PsNumericOpTrait):
@property
def python_operator(self):
return operator.neg
......@@ -503,31 +519,31 @@ class PsBinOp(PsExpression):
return None
class PsAdd(PsBinOp):
class PsAdd(PsBinOp, PsNumericOpTrait):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.add
class PsSub(PsBinOp):
class PsSub(PsBinOp, PsNumericOpTrait):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.sub
class PsMul(PsBinOp):
class PsMul(PsBinOp, PsNumericOpTrait):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.mul
class PsDiv(PsBinOp):
class PsDiv(PsBinOp, PsNumericOpTrait):
# python_operator not implemented because can't unambigously decide
# between intdiv and truediv
pass
class PsIntDiv(PsBinOp):
class PsIntDiv(PsBinOp, PsIntOpTrait):
"""C-like integer division (round to zero)."""
# python_operator not implemented because both floordiv and truediv have
......@@ -535,36 +551,94 @@ class PsIntDiv(PsBinOp):
pass
class PsLeftShift(PsBinOp):
class PsLeftShift(PsBinOp, PsIntOpTrait):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.lshift
class PsRightShift(PsBinOp):
class PsRightShift(PsBinOp, PsIntOpTrait):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.rshift
class PsBitwiseAnd(PsBinOp):
class PsBitwiseAnd(PsBinOp, PsIntOpTrait):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.and_
class PsBitwiseXor(PsBinOp):
class PsBitwiseXor(PsBinOp, PsIntOpTrait):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.xor
class PsBitwiseOr(PsBinOp):
class PsBitwiseOr(PsBinOp, PsIntOpTrait):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.or_
class PsAnd(PsBinOp, PsBoolOpTrait):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.and_
class PsOr(PsBinOp, PsBoolOpTrait):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.or_
class PsNot(PsUnOp, PsBoolOpTrait):
@property
def python_operator(self) -> Callable[[Any], Any] | None:
return operator.not_
class PsRel(PsBinOp):
"""Base class for binary relational operators"""
class PsEq(PsRel):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.eq
class PsNe(PsRel):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.ne
class PsGe(PsRel):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.ge
class PsLe(PsRel):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.le
class PsGt(PsRel):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.gt
class PsLt(PsRel):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.lt
class PsArrayInitList(PsExpression):
__match_args__ = ("items",)
......
from typing import Callable, Any
import operator
from .expressions import PsExpression
from .astnode import PsAstNode
from .util import failing_cast
class PsLogicalExpression(PsExpression):
__match_args__ = ("operand1", "operand2")
def __init__(self, op1: PsExpression, op2: PsExpression):
super().__init__()
self._op1 = op1
self._op2 = op2
@property
def operand1(self) -> PsExpression:
return self._op1
@operand1.setter
def operand1(self, expr: PsExpression):
self._op1 = expr
@property
def operand2(self) -> PsExpression:
return self._op2
@operand2.setter
def operand2(self, expr: PsExpression):
self._op2 = expr
def clone(self):
return type(self)(self._op1.clone(), self._op2.clone())
def get_children(self) -> tuple[PsAstNode, ...]:
return self._op1, self._op2
def set_child(self, idx: int, c: PsAstNode):
idx = [0, 1][idx]
match idx:
case 0:
self._op1 = failing_cast(PsExpression, c)
case 1:
self._op2 = failing_cast(PsExpression, c)
def __repr__(self) -> str:
opname = self.__class__.__name__
return f"{opname}({repr(self._op1)}, {repr(self._op2)})"
@property
def python_operator(self) -> None | Callable[[Any, Any], Any]:
return None
class PsAnd(PsLogicalExpression):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.and_
class PsEq(PsLogicalExpression):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.eq
class PsGe(PsLogicalExpression):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.ge
class PsGt(PsLogicalExpression):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.gt
class PsLe(PsLogicalExpression):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.le
class PsLt(PsLogicalExpression):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.lt
class PsNe(PsLogicalExpression):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.ne
......@@ -4,6 +4,7 @@ from types import NoneType
from .astnode import PsAstNode, PsLeafMixIn
from .expressions import PsExpression, PsLvalue, PsSymbolExpr
from ..symbols import PsSymbol
from .util import failing_cast
......@@ -121,7 +122,7 @@ class PsAssignment(PsAstNode):
class PsDeclaration(PsAssignment):
__match_args__ = (
"declared_variable",
"lhs",
"rhs",
)
......@@ -137,12 +138,8 @@ class PsDeclaration(PsAssignment):
self._lhs = failing_cast(PsSymbolExpr, lvalue)
@property
def declared_variable(self) -> PsSymbolExpr:
return cast(PsSymbolExpr, self._lhs)
@declared_variable.setter
def declared_variable(self, lvalue: PsSymbolExpr):
self._lhs = lvalue
def declared_symbol(self) -> PsSymbol:
return cast(PsSymbolExpr, self._lhs).symbol
def clone(self) -> PsDeclaration:
return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone())
......
from __future__ import annotations
from typing import Any
from ..types import PsNumericType, constify
......@@ -5,6 +6,21 @@ from .exceptions import PsInternalCompilerError
class PsConstant:
"""Type-safe representation of typed numerical constants.
This class models constants in the backend representation of kernels.
A constant may be *untyped*, in which case its ``value`` may be any Python object.
If the constant is *typed* (i.e. its ``dtype`` is not ``None``), its data type is used
to check the validity of its ``value`` and to convert it into the type's internal representation.
Instances of `PsConstant` are immutable.
Args:
value: The constant's value
dtype: The constant's data type, or ``None`` if untyped.
"""
__match_args__ = ("value", "dtype")
def __init__(self, value: Any, dtype: PsNumericType | None = None):
......@@ -12,7 +28,30 @@ class PsConstant:
self._value = value
if dtype is not None:
self.apply_dtype(dtype)
self._dtype = constify(dtype)
self._value = self._dtype.create_constant(self._value)
else:
self._dtype = None
self._value = value
def interpret_as(self, dtype: PsNumericType) -> PsConstant:
"""Interprets this *untyped* constant with the given data type.
If this constant is already typed, raises an error.
"""
if self._dtype is not None:
raise PsInternalCompilerError(
f"Cannot interpret already typed constant {self} with type {dtype}"
)
return PsConstant(self._value, dtype)
def reinterpret_as(self, dtype: PsNumericType) -> PsConstant:
"""Reinterprets this constant with the given data type.
Other than `interpret_as`, this method also works on typed constants.
"""
return PsConstant(self._value, dtype)
@property
def value(self) -> Any:
......@@ -20,22 +59,18 @@ class PsConstant:
@property
def dtype(self) -> PsNumericType | None:
"""This constant's data type, or ``None`` if it is untyped.
The data type of a constant always has ``const == True``.
"""
return self._dtype
def get_dtype(self) -> PsNumericType:
"""Retrieve this constant's data type, throwing an exception if the constant is untyped."""
if self._dtype is None:
raise PsInternalCompilerError("Data type of constant was not set.")
return self._dtype
def apply_dtype(self, dtype: PsNumericType):
if self._dtype is not None:
raise PsInternalCompilerError(
"Attempt to apply data type to already typed constant."
)
self._dtype = constify(dtype)
self._value = self._dtype.create_constant(self._value)
def __str__(self) -> str:
type_str = "<untyped>" if self._dtype is None else str(self._dtype)
return f"{str(self._value)}: {type_str}"
......
......@@ -35,6 +35,15 @@ from .ast.expressions import (
PsSubscript,
PsSymbolExpr,
PsVectorArrayAccess,
PsAnd,
PsOr,
PsNot,
PsEq,
PsNe,
PsGt,
PsLt,
PsGe,
PsLe,
)
from .symbols import PsSymbol
......@@ -67,32 +76,41 @@ class Ops(Enum):
See also https://en.cppreference.com/w/cpp/language/operator_precedence
"""
Weakest = (17 - 17, LR.Middle)
Call = (2, LR.Left)
Subscript = (2, LR.Left)
Lookup = (2, LR.Left)
BitwiseOr = (17 - 13, LR.Left)
Neg = (3, LR.Right)
Not = (3, LR.Right)
AddressOf = (3, LR.Right)
Deref = (3, LR.Right)
Cast = (3, LR.Right)
BitwiseXor = (17 - 12, LR.Left)
Mul = (5, LR.Left)
Div = (5, LR.Left)
Rem = (5, LR.Left)
BitwiseAnd = (17 - 11, LR.Left)
Add = (6, LR.Left)
Sub = (6, LR.Left)
LeftShift = (17 - 7, LR.Left)
RightShift = (17 - 7, LR.Left)
LeftShift = (7, LR.Left)
RightShift = (7, LR.Left)
Add = (17 - 6, LR.Left)
Sub = (17 - 6, LR.Left)
RelOp = (9, LR.Left) # >=, >, <, <=
Mul = (17 - 5, LR.Left)
Div = (17 - 5, LR.Left)
Rem = (17 - 5, LR.Left)
EqOp = (10, LR.Left) # == and !=
Neg = (17 - 3, LR.Right)
AddressOf = (17 - 3, LR.Right)
Deref = (17 - 3, LR.Right)
Cast = (17 - 3, LR.Right)
BitwiseAnd = (11, LR.Left)
Call = (17 - 2, LR.Left)
Subscript = (17 - 2, LR.Left)
Lookup = (17 - 2, LR.Left)
BitwiseXor = (12, LR.Left)
BitwiseOr = (13, LR.Left)
LogicAnd = (14, LR.Left)
LogicOr = (15, LR.Left)
Weakest = (17, LR.Middle)
def __init__(self, pred: int, assoc: LR) -> None:
self.precedence = pred
......@@ -125,7 +143,7 @@ class PrinterCtx:
return self.branch_stack[-1]
def parenthesize(self, expr: str, next_operator: Ops) -> str:
if next_operator.precedence < self.current_op.precedence:
if next_operator.precedence > self.current_op.precedence:
return f"({expr})"
elif (
next_operator.precedence == self.current_op.precedence
......@@ -169,7 +187,7 @@ class CAstPrinter:
return pc.indent(f"{self.visit(expr, pc)};")
case PsDeclaration(lhs, rhs):
lhs_symb = lhs.symbol
lhs_symb = node.declared_symbol
lhs_code = self._symbol_decl(lhs_symb)
rhs_code = self.visit(rhs, pc)
......@@ -274,6 +292,13 @@ class CAstPrinter:
return pc.parenthesize(f"-{operand_code}", Ops.Neg)
case PsNot(operand):
pc.push_op(Ops.Not, LR.Right)
operand_code = self.visit(operand, pc)
pc.pop_op()
return pc.parenthesize(f"!{operand_code}", Ops.Not)
case PsDeref(operand):
pc.push_op(Ops.Deref, LR.Right)
operand_code = self.visit(operand, pc)
......@@ -339,5 +364,21 @@ class CAstPrinter:
return ("^", Ops.BitwiseXor)
case PsBitwiseOr():
return ("|", Ops.BitwiseOr)
case PsAnd():
return ("&&", Ops.LogicAnd)
case PsOr():
return ("||", Ops.LogicOr)
case PsEq():
return ("==", Ops.EqOp)
case PsNe():
return ("!=", Ops.EqOp)
case PsGt():
return (">", Ops.RelOp)
case PsGe():
return (">=", Ops.RelOp)
case PsLt():
return ("<", Ops.RelOp)
case PsLe():
return ("<=", Ops.RelOp)
case _:
assert False
......@@ -2,6 +2,7 @@ from .context import KernelCreationContext
from .analysis import KernelAnalysis
from .freeze import FreezeExpressions
from .typification import Typifier
from .ast_factory import AstFactory
from .iteration_space import (
FullIterationSpace,
......@@ -17,6 +18,7 @@ __all__ = [
"KernelAnalysis",
"FreezeExpressions",
"Typifier",
"AstFactory",
"FullIterationSpace",
"SparseIterationSpace",
"create_full_iteration_space",
......
from typing import Any, Sequence, cast, overload
import numpy as np
import sympy as sp
from sympy.codegen.ast import AssignmentBase
from ..ast import PsAstNode
from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr
from ..ast.structural import PsLoop, PsBlock, PsAssignment
from ..symbols import PsSymbol
from ..constants import PsConstant
from .context import KernelCreationContext
from .freeze import FreezeExpressions
from .typification import Typifier
from .iteration_space import FullIterationSpace
IndexParsable = PsExpression | PsSymbol | PsConstant | sp.Expr | int | np.integer
_IndexParsable = (PsExpression, PsSymbol, PsConstant, sp.Expr, int, np.integer)
class AstFactory:
"""Factory providing a convenient interface for building syntax trees.
The `AstFactory` uses the defaults provided by the given `KernelCreationContext` to quickly create
AST nodes. Depending on context (numerical, loop indexing, etc.), symbols and constants receive either
``ctx.default_dtype`` or ``ctx.index_dtype``.
Args:
ctx: The kernel creation context
"""
def __init__(self, ctx: KernelCreationContext):
self._ctx = ctx
self._freeze = FreezeExpressions(ctx)
self._typify = Typifier(ctx)
@overload
def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression:
pass
@overload
def parse_sympy(self, sp_obj: AssignmentBase) -> PsAssignment:
pass
def parse_sympy(self, sp_obj: sp.Expr | AssignmentBase) -> PsAstNode:
"""Parse a SymPy expression or assignment through `FreezeExpressions` and `Typifier`.
The expression or assignment will be typified in a numerical context, using the kernel
creation context's `default_dtype`.
Args:
sp_obj: A SymPy expression or assignment
"""
return self._typify(self._freeze(sp_obj))
@overload
def parse_index(self, idx: sp.Symbol | PsSymbol | PsSymbolExpr) -> PsSymbolExpr:
pass
@overload
def parse_index(
self, idx: int | np.integer | PsConstant | PsConstantExpr
) -> PsConstantExpr:
pass
@overload
def parse_index(self, idx: sp.Expr | PsExpression) -> PsExpression:
pass
def parse_index(self, idx: IndexParsable):
"""Parse the given object as an expression with data type `ctx.index_dtype`."""
if not isinstance(idx, _IndexParsable):
raise TypeError(
f"Cannot parse object of type {type(idx)} as an index expression"
)
match idx:
case PsExpression():
return self._typify.typify_expression(idx, self._ctx.index_dtype)[0]
case PsSymbol() | PsConstant():
return self._typify.typify_expression(
PsExpression.make(idx), self._ctx.index_dtype
)[0]
case sp.Expr():
return self._typify.typify_expression(
self._freeze(idx), self._ctx.index_dtype
)[0]
case _:
return PsExpression.make(PsConstant(idx, self._ctx.index_dtype))
def _parse_any_index(self, idx: Any) -> PsExpression:
return self.parse_index(cast(IndexParsable, idx))
def parse_slice(
self, slic: slice, upper_limit: Any | None = None
) -> tuple[PsExpression, PsExpression, PsExpression]:
"""Parse a slice to obtain start, stop and step expressions for a loop or iteration space dimension.
The slice entries may be instances of `PsExpression`, `PsSymbol` or `PsConstant`, in which case they
must typify with the kernel creation context's ``index_dtype``.
They may also be sympy expressions or integer constants, in which case they are parsed to AST objects
and must also typify with the kernel creation context's ``index_dtype``.
If the slice's ``stop`` member is `None` or a negative `int`, `upper_limit` must be specified, which is then
used as the upper iteration limit as either ``upper_limit`` or ``upper_limit - stop``.
Args:
slic: The iteration slice
upper_limit: Optionally, the upper iteration limit
"""
if slic.stop is None or (isinstance(slic.stop, int) and slic.stop < 0):
if upper_limit is None:
raise ValueError(
"Must specify an upper iteration limit if `slice.stop` is `None` or a negative `int`"
)
start = self._parse_any_index(slic.start if slic.start is not None else 0)
stop = (
self._parse_any_index(slic.stop)
if slic.stop is not None
else self._parse_any_index(upper_limit)
)
step = self._parse_any_index(slic.step if slic.step is not None else 1)
if isinstance(slic.stop, int) and slic.stop < 0:
stop = self._parse_any_index(upper_limit) + stop
return start, stop, step
def loop(self, ctr_name: str, iteration_slice: slice, body: PsBlock):
"""Create a loop from a slice.
Args:
ctr_name: Name of the loop counter
iteration_slice: The iteration region as a slice; see `parse_slice`.
body: The loop body
"""
ctr = PsExpression.make(self._ctx.get_symbol(ctr_name, self._ctx.index_dtype))
start, stop, step = self.parse_slice(iteration_slice)
return PsLoop(
ctr,
start,
stop,
step,
body,
)
def loop_nest(
self, counters: Sequence[str], slices: Sequence[slice], body: PsBlock
) -> PsLoop:
"""Create a loop nest from a sequence of slices.
**Example:**
This snippet creates a 3D loop nest with ten iterations in each dimension::
>>> from pystencils import make_slice
>>> ctx = KernelCreationContext()
>>> factory = AstFactory(ctx)
>>> loop = factory.loop_nest(("i", "j", "k"), make_slice[:10,:10,:10], PsBlock([]))
Args:
counters: Sequence of names for the loop counters
slices: Sequence of iteration slices; see also `parse_slice`
body: The loop body
"""
if not slices:
raise ValueError(
"At least one slice must be specified to create a loop nest."
)
ast = body
for ctr_name, sl in zip(counters[::-1], slices[::-1], strict=True):
ast = self.loop(
ctr_name,
sl,
PsBlock([ast]) if not isinstance(ast, PsBlock) else ast,
)
return cast(PsLoop, ast)
def loops_from_ispace(
self,
ispace: FullIterationSpace,
body: PsBlock,
loop_order: Sequence[int] | None = None,
) -> PsLoop:
"""Create a loop nest from a dense iteration space.
Args:
ispace: The iteration space object
body: The loop body
loop_order: Optionally, a permutation of integers indicating the order of loops
"""
dimensions = ispace.dimensions
if loop_order is not None:
dimensions = [dimensions[coordinate] for coordinate in loop_order]
outer_node: PsLoop | PsBlock = body
for dimension in dimensions[::-1]:
outer_node = PsLoop(
PsSymbolExpr(dimension.counter),
dimension.start,
dimension.stop,
dimension.step,
(
outer_node
if isinstance(outer_node, PsBlock)
else PsBlock([outer_node])
),
)
assert isinstance(outer_node, PsLoop)
return outer_node
from __future__ import annotations
from typing import Iterable, Iterator
from itertools import chain
from itertools import chain, count
from types import EllipsisType
from collections import namedtuple
from collections import namedtuple, defaultdict
import re
from ...defaults import DEFAULTS
from ...field import Field, FieldType
......@@ -67,6 +68,9 @@ class KernelCreationContext:
self._symbols: dict[str, PsSymbol] = dict()
self._symbol_ctr_pattern = re.compile(r"__[0-9]+$")
self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0)
self._fields_and_arrays: dict[str, FieldArrayPair] = dict()
self._fields_collection = FieldsInKernel()
......@@ -95,6 +99,21 @@ class KernelCreationContext:
# Symbols
def get_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol:
"""Retrieve the symbol with the given name and data type from the symbol table.
If no symbol named ``name`` exists, a new symbol with the given data type is created.
If a symbol with the given ``name`` already exists and ``dtype`` is not `None`,
the given data type will be applied to it, and it is returned.
If the symbol already has a different data type, an error will be raised.
If the symbol already exists and ``dtype`` is `None`, the existing symbol is returned
without checking or altering its data type.
Args:
name: The symbol's name
dtype: The symbol's data type, or `None`
"""
if name not in self._symbols:
symb = PsSymbol(name, None)
self._symbols[name] = symb
......@@ -106,13 +125,29 @@ class KernelCreationContext:
return symb
def find_symbol(self, name: str) -> PsSymbol | None:
"""Find a symbol with the given name in the symbol table, if it exists.
Returns:
The symbol with the given name, or `None` if no such symbol exists.
"""
return self._symbols.get(name, None)
def add_symbol(self, symbol: PsSymbol):
"""Add an existing symbol to the symbol table.
If a symbol with the same name already exists, an error will be raised.
"""
if symbol.name in self._symbols:
raise PsInternalCompilerError(f"Duplicate symbol: {symbol.name}")
self._symbols[symbol.name] = symbol
def replace_symbol(self, old: PsSymbol, new: PsSymbol):
"""Replace one symbol by another.
The two symbols ``old`` and ``new`` must have the same name, but may have different data types.
"""
if old.name != new.name:
raise PsInternalCompilerError(
"replace_symbol: Old and new symbol must have the same name"
......@@ -123,8 +158,30 @@ class KernelCreationContext:
self._symbols[old.name] = new
def duplicate_symbol(self, symb: PsSymbol) -> PsSymbol:
"""Canonically duplicates the given symbol.
A new symbol with the same data type, and new name ``symb.name + "__<counter>"`` is created,
added to the symbol table, and returned.
The ``counter`` reflects the number of previously created duplicates of this symbol.
"""
if (result := self._symbol_ctr_pattern.search(symb.name)) is not None:
span = result.span()
basename = symb.name[: span[0]]
else:
basename = symb.name
initial_count = self._symbol_dup_table[basename]
for i in count(initial_count):
dup_name = f"{basename}__{i}"
if self.find_symbol(dup_name) is None:
self._symbol_dup_table[basename] = i + 1
return self.get_symbol(dup_name, symb.dtype)
assert False, "unreachable code"
@property
def symbols(self) -> Iterable[PsSymbol]:
"""Return an iterable of all symbols listed in the symbol table."""
return self._symbols.values()
# Fields and Arrays
......
from __future__ import annotations
from typing import cast
from .context import KernelCreationContext
from ..platforms import GenericCpu
from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
from ..ast.structural import PsBlock
from ...config import CpuOptimConfig
......@@ -11,10 +13,19 @@ def optimize_cpu(
ctx: KernelCreationContext,
platform: GenericCpu,
kernel_ast: PsBlock,
cfg: CpuOptimConfig,
):
cfg: CpuOptimConfig | None,
) -> PsBlock:
"""Carry out CPU-specific optimizations according to the given configuration."""
canonicalize = CanonicalizeSymbols(ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
hoist_invariants = HoistLoopInvariantDeclarations(ctx)
kernel_ast = cast(PsBlock, hoist_invariants(kernel_ast))
if cfg is None:
return kernel_ast
if cfg.loop_blocking:
raise NotImplementedError("Loop blocking not implemented yet.")
......@@ -26,3 +37,5 @@ def optimize_cpu(
if cfg.use_cacheline_zeroing:
raise NotImplementedError("CL-zeroing not implemented yet")
return kernel_ast
from typing import overload, cast, Any
from functools import reduce
from operator import add, mul, sub
from operator import add, mul, sub, truediv
import sympy as sp
import sympy.core.relational
import sympy.logic.boolalg
from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
from ...sympyextensions import Assignment, AssignmentCollection, integer_functions
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc
......@@ -33,6 +36,16 @@ from ..ast.expressions import (
PsRightShift,
PsSubscript,
PsVectorArrayAccess,
PsRel,
PsEq,
PsNe,
PsLt,
PsGt,
PsLe,
PsGe,
PsAnd,
PsOr,
PsNot,
)
from ..constants import PsConstant
......@@ -45,6 +58,20 @@ class FreezeError(Exception):
"""Signifies an error during expression freezing."""
ExprLike = (
sp.Expr
| sp.Tuple
| sympy.core.relational.Relational
| sympy.logic.boolalg.BooleanFunction
)
_ExprLike = (
sp.Expr,
sp.Tuple,
sympy.core.relational.Relational,
sympy.logic.boolalg.BooleanFunction,
)
class FreezeExpressions:
"""Convert expressions and kernels expressed in the SymPy language to the code generator's internal representation.
......@@ -64,19 +91,19 @@ class FreezeExpressions:
pass
@overload
def __call__(self, obj: sp.Expr) -> PsExpression:
def __call__(self, obj: ExprLike) -> PsExpression:
pass
@overload
def __call__(self, obj: Assignment) -> PsAssignment:
def __call__(self, obj: AssignmentBase) -> PsAssignment:
pass
def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode:
if isinstance(obj, AssignmentCollection):
return PsBlock([self.visit(asm) for asm in obj.all_assignments])
elif isinstance(obj, Assignment):
elif isinstance(obj, AssignmentBase):
return cast(PsAssignment, self.visit(obj))
elif isinstance(obj, sp.Expr):
elif isinstance(obj, _ExprLike):
return cast(PsExpression, self.visit(obj))
else:
raise PsInputError(f"Don't know how to freeze {obj}")
......@@ -96,8 +123,8 @@ class FreezeExpressions:
raise FreezeError(f"Don't know how to freeze expression {node}")
def visit_expr_like(self, obj: Any) -> PsExpression:
if isinstance(obj, sp.Basic):
def visit_expr_or_builtin(self, obj: Any) -> PsExpression:
if isinstance(obj, _ExprLike):
return self.visit_expr(obj)
elif isinstance(obj, (int, float, bool)):
return PsExpression.make(PsConstant(obj))
......@@ -105,7 +132,7 @@ class FreezeExpressions:
raise FreezeError(f"Don't know how to freeze {obj}")
def visit_expr(self, expr: sp.Basic):
if not isinstance(expr, (sp.Expr, sp.Tuple)):
if not isinstance(expr, _ExprLike):
raise FreezeError(f"Cannot freeze {expr} to an expression")
return cast(PsExpression, self.visit(expr))
......@@ -128,6 +155,27 @@ class FreezeExpressions:
f"Encountered unsupported expression on assignment left-hand side: {lhs}"
)
def map_AugmentedAssignment(self, expr: AugmentedAssignment):
lhs = self.visit(expr.lhs)
rhs = self.visit(expr.rhs)
assert isinstance(lhs, PsExpression)
assert isinstance(rhs, PsExpression)
match expr.op:
case "+=":
op = add
case "-=":
op = sub
case "*=":
op = mul
case "/=":
op = truediv
case _:
raise FreezeError(f"Unsupported augmented assignment: {expr.op}.")
return PsAssignment(lhs, op(lhs.clone(), rhs))
def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
symb = self._ctx.get_symbol(spsym.name)
return PsSymbolExpr(symb)
......@@ -235,7 +283,9 @@ class FreezeExpressions:
array = self._ctx.get_array(field)
ptr = array.base_pointer
offsets: list[PsExpression] = [self.visit_expr_like(o) for o in access.offsets]
offsets: list[PsExpression] = [
self.visit_expr_or_builtin(o) for o in access.offsets
]
indices: list[PsExpression]
if not access.is_absolute_access:
......@@ -281,7 +331,7 @@ class FreezeExpressions:
)
else:
struct_member_name = None
indices = [self.visit_expr_like(i) for i in access.index]
indices = [self.visit_expr_or_builtin(i) for i in access.index]
if not indices:
# For canonical representation, there must always be at least one index dimension
indices = [PsExpression.make(PsConstant(0))]
......@@ -349,5 +399,35 @@ class FreezeExpressions:
args = tuple(self.visit_expr(arg) for arg in expr.args)
return PsCall(PsMathFunction(MathFunctions.Max), args)
def map_CastFunc(self, cast_expr: CastFunc):
def map_CastFunc(self, cast_expr: CastFunc) -> PsCast:
return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr))
def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel:
arg1, arg2 = [self.visit_expr(arg) for arg in rel.args]
match rel.rel_op: # type: ignore
case "==":
return PsEq(arg1, arg2)
case "!=":
return PsNe(arg1, arg2)
case ">=":
return PsGe(arg1, arg2)
case "<=":
return PsLe(arg1, arg2)
case ">":
return PsGt(arg1, arg2)
case "<":
return PsLt(arg1, arg2)
case other:
raise FreezeError(f"Unsupported relation: {other}")
def map_And(self, conj: sympy.logic.And) -> PsAnd:
arg1, arg2 = [self.visit_expr(arg) for arg in conj.args]
return PsAnd(arg1, arg2)
def map_Or(self, disj: sympy.logic.Or) -> PsOr:
arg1, arg2 = [self.visit_expr(arg) for arg in disj.args]
return PsOr(arg1, arg2)
def map_Not(self, neg: sympy.logic.Not) -> PsNot:
arg = self.visit_expr(neg.args[0])
return PsNot(arg)
......@@ -5,8 +5,6 @@ from dataclasses import dataclass
from functools import reduce
from operator import mul
import sympy as sp
from ...defaults import DEFAULTS
from ...sympyextensions import AssignmentCollection
from ...field import Field, FieldType
......@@ -71,8 +69,8 @@ class FullIterationSpace(IterationSpace):
@staticmethod
def create_with_ghost_layers(
ctx: KernelCreationContext,
archetype_field: Field,
ghost_layers: int | Sequence[int | tuple[int, int]],
archetype_field: Field,
) -> FullIterationSpace:
"""Create an iteration space over an archetype field with ghost layers."""
......@@ -123,56 +121,56 @@ class FullIterationSpace(IterationSpace):
@staticmethod
def create_from_slice(
ctx: KernelCreationContext,
archetype_field: Field,
iteration_slice: Sequence[slice],
iteration_slice: slice | Sequence[slice],
archetype_field: Field | None = None,
):
archetype_array = ctx.get_array(archetype_field)
dim = archetype_field.spatial_dimensions
if len(iteration_slice) != dim:
"""Create an iteration space from a sequence of slices, optionally over an archetype field.
Args:
ctx: The kernel creation context
iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice`
archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order.
"""
if isinstance(iteration_slice, slice):
iteration_slice = (iteration_slice,)
dim = len(iteration_slice)
if dim == 0:
raise ValueError(
f"Number of dimensions in slice ({len(iteration_slice)}) "
f" did not equal iteration space dimensionality ({dim})"
"At least one slice must be specified to create an iteration space"
)
archetype_size: tuple[PsSymbol | PsConstant | None, ...]
if archetype_field is not None:
archetype_array = ctx.get_array(archetype_field)
if archetype_field.spatial_dimensions != dim:
raise ValueError(
f"Number of dimensions in slice ({len(iteration_slice)}) "
f" did not equal iteration space dimensionality ({dim})"
)
archetype_size = archetype_array.shape[:dim]
else:
archetype_size = (None,) * dim
counters = [
ctx.get_symbol(name, ctx.index_dtype)
for name in DEFAULTS.spatial_counter_names[:dim]
]
from .freeze import FreezeExpressions
from .typification import Typifier
freeze = FreezeExpressions(ctx)
typifier = Typifier(ctx)
def expr_convert(expr) -> PsExpression:
if isinstance(expr, int):
return PsConstantExpr(PsConstant(expr, ctx.index_dtype))
elif isinstance(expr, sp.Expr):
typed_expr, _ = typifier.typify_expression(
freeze.freeze_expression(expr), ctx.index_dtype
)
return typed_expr
else:
raise ValueError(f"Invalid entry in slice: {expr}")
def to_dim(slic: slice, size: PsSymbol | PsConstant, ctr: PsSymbol):
size_expr = PsExpression.make(size)
start = expr_convert(slic.start if slic.start is not None else 0)
stop = expr_convert(slic.stop) if slic.stop is not None else size_expr
step = expr_convert(slic.step if slic.step is not None else 1)
from .ast_factory import AstFactory
if isinstance(slic.stop, int) and slic.stop < 0:
stop = size_expr + stop # todo
factory = AstFactory(ctx)
def to_dim(slic: slice, size: PsSymbol | PsConstant | None, ctr: PsSymbol):
start, stop, step = factory.parse_slice(slic, size)
return FullIterationSpace.Dimension(start, stop, step, ctr)
dimensions = [
to_dim(slic, size, ctr)
for slic, size, ctr in zip(
iteration_slice, archetype_array.shape[:dim], counters, strict=True
iteration_slice, archetype_size, counters, strict=True
)
]
......@@ -399,13 +397,13 @@ def create_full_iteration_space(
if ghost_layers is not None:
return FullIterationSpace.create_with_ghost_layers(
ctx, archetype_field, ghost_layers
ctx, ghost_layers, archetype_field
)
elif iteration_slice is not None:
return FullIterationSpace.create_from_slice(
ctx, archetype_field, iteration_slice
ctx, iteration_slice, archetype_field
)
else:
return FullIterationSpace.create_with_ghost_layers(
ctx, archetype_field, inferred_gls
ctx, inferred_gls, archetype_field
)
......@@ -12,7 +12,7 @@ from ...types import (
PsDereferencableType,
PsPointerType,
PsBoolType,
deconstify,
constify,
)
from ..ast.structural import (
PsAstNode,
......@@ -21,25 +21,27 @@ from ..ast.structural import (
PsConditional,
PsExpression,
PsAssignment,
PsDeclaration,
PsComment,
)
from ..ast.expressions import (
PsArrayAccess,
PsArrayInitList,
PsBinOp,
PsBitwiseAnd,
PsBitwiseOr,
PsBitwiseXor,
PsIntOpTrait,
PsNumericOpTrait,
PsBoolOpTrait,
PsCall,
PsCast,
PsDeref,
PsAddressOf,
PsConstantExpr,
PsIntDiv,
PsLeftShift,
PsLookup,
PsRightShift,
PsSubscript,
PsSymbolExpr,
PsRel,
PsNeg,
PsNot,
)
from ..functions import PsMathFunction
......@@ -54,20 +56,48 @@ NodeT = TypeVar("NodeT", bound=PsAstNode)
class TypeContext:
def __init__(self, target_type: PsType | None = None):
self._target_type = deconstify(target_type) if target_type is not None else None
"""Typing context, with support for type inference and checking.
Instances of this class are used to propagate and check data types across expression subtrees
of the AST. Each type context has:
- A target type `target_type`, which shall be applied to all expressions it covers
- A set of restrictions on the target type:
- `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides
- Additional restrictions may be added in the future.
"""
def __init__(
self, target_type: PsType | None = None, require_nonconst: bool = False
):
self._require_nonconst = require_nonconst
self._deferred_exprs: list[PsExpression] = []
def apply_dtype(self, expr: PsExpression | None, dtype: PsType):
"""Applies the given ``dtype`` to the given expression inside this type context.
self._target_type = (
self._fix_constness(target_type) if target_type is not None else None
)
@property
def target_type(self) -> PsType | None:
return self._target_type
@property
def require_nonconst(self) -> bool:
return self._require_nonconst
def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
"""Applies the given ``dtype`` to this type context, and optionally to the given expression.
The given expression will be covered by this type context.
If the context's target_type is already known, it must be compatible with the given dtype.
If the target type is still unknown, target_type is set to dtype and retroactively applied
to all deferred expressions.
If an expression is specified, it will be covered by the type context.
If the expression already has a data type set, it must be compatible with the target type
and will be replaced by it.
"""
dtype = deconstify(dtype)
dtype = self._fix_constness(dtype)
if self._target_type is not None and dtype != self._target_type:
raise TypificationError(
......@@ -80,14 +110,7 @@ class TypeContext:
self._propagate_target_type()
if expr is not None:
if expr.dtype is None:
self._apply_target_type(expr)
elif deconstify(expr.dtype) != self._target_type:
raise TypificationError(
"Type conflict: Predefined expression type did not match the context's target type\n"
f" Expression type: {dtype}\n"
f" Target type: {self._target_type}"
)
self._apply_target_type(expr)
def infer_dtype(self, expr: PsExpression):
"""Infer the data type for the given expression.
......@@ -96,7 +119,8 @@ class TypeContext:
Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is
called on this context.
It the expression already has a data type set, it must be equal to the inferred type.
If the expression already has a data type set, it must be compatible with the target type
and will be replaced by it.
"""
if self._target_type is None:
......@@ -113,7 +137,7 @@ class TypeContext:
assert self._target_type is not None
if expr.dtype is not None:
if deconstify(expr.dtype) != self.target_type:
if not self._compatible(expr.dtype):
raise TypificationError(
f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
f" Expression type: {expr.dtype}\n"
......@@ -126,36 +150,83 @@ class TypeContext:
raise TypificationError(
f"Can't typify constant with non-numeric type {self._target_type}"
)
c.apply_dtype(self._target_type)
if c.dtype is None:
expr.constant = c.interpret_as(self._target_type)
elif not self._compatible(c.dtype):
raise TypificationError(
f"Type mismatch at constant {c}: Constant type did not match the context's target type\n"
f" Constant type: {c.dtype}\n"
f" Target type: {self._target_type}"
)
case PsSymbolExpr(symb):
symb.apply_dtype(self._target_type)
case (
PsIntDiv()
| PsLeftShift()
| PsRightShift()
| PsBitwiseAnd()
| PsBitwiseXor()
| PsBitwiseOr()
) if not isinstance(self._target_type, PsIntegerType):
assert symb.dtype is not None
if not self._compatible(symb.dtype):
raise TypificationError(
f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n"
f" Symbol type: {symb.dtype}\n"
f" Target type: {self._target_type}"
)
case PsNumericOpTrait() if not isinstance(
self._target_type, PsNumericType
) or isinstance(self._target_type, PsBoolType):
# FIXME: PsBoolType derives from PsNumericType, but is not numeric
raise TypificationError(
f"Numerical operation encountered in non-numerical type context:\n"
f" Expression: {expr}"
f" Type Context: {self._target_type}"
)
case PsIntOpTrait() if not isinstance(self._target_type, PsIntegerType):
raise TypificationError(
f"Integer operation encountered in non-integer type context:\n"
f" Expression: {expr}"
f" Type Context: {self._target_type}"
)
expr.dtype = self._target_type
case PsBoolOpTrait() if not isinstance(self._target_type, PsBoolType):
raise TypificationError(
f"Boolean operation encountered in non-boolean type context:\n"
f" Expression: {expr}"
f" Type Context: {self._target_type}"
)
# endif
expr.dtype = self._target_type
@property
def target_type(self) -> PsType | None:
return self._target_type
def _compatible(self, dtype: PsType):
"""Checks whether the given data type is compatible with the context's target type.
If the target type is ``const``, they must be equal up to const qualification;
if the target type is not ``const``, `dtype` must match it exactly.
"""
assert self._target_type is not None
if self._target_type.const:
return constify(dtype) == self._target_type
else:
return dtype == self._target_type
def _fix_constness(self, dtype: PsType, expr: PsExpression | None = None):
if self._require_nonconst:
if dtype.const:
if expr is None:
raise TypificationError(
f"Type mismatch: Encountered {dtype} in non-constant context."
)
else:
raise TypificationError(
f"Type mismatch at expression {expr}: Encountered {dtype} in non-constant context."
)
return dtype
else:
return constify(dtype)
class Typifier:
"""Apply data types to expressions.
**Contextual Typing**
The Typifier will traverse the AST and apply a contextual typing scheme to figure out
the data types of all encountered expressions.
To this end, it covers each expression tree with a set of disjoint typing contexts.
......@@ -176,6 +247,21 @@ class Typifier:
in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon
as the target type is fixed.
**Typing Rules**
The following general rules apply:
- The context's `default_dtype` is applied to all untyped symbols
- By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's
left-hand side
**Typing of symbol expressions**
Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but
not necessarily their const-qualification.
A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type,
and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`,
but not vice versa.
"""
def __init__(self, ctx: KernelCreationContext):
......@@ -206,15 +292,23 @@ class Typifier:
for s in statements:
self.visit(s)
case PsAssignment(lhs, rhs):
case PsDeclaration(lhs, rhs):
tc = TypeContext()
# LHS defines target type; type context carries it to RHS
self.visit_expr(lhs, tc)
assert tc.target_type is not None
self.visit_expr(rhs, tc)
case PsAssignment(lhs, rhs):
tc_lhs = TypeContext(require_nonconst=True)
self.visit_expr(lhs, tc_lhs)
assert tc_lhs.target_type is not None
tc_rhs = TypeContext(tc_lhs.target_type, require_nonconst=False)
self.visit_expr(rhs, tc_rhs)
case PsConditional(cond, branch_true, branch_false):
cond_tc = TypeContext(PsBoolType(const=True))
cond_tc = TypeContext(PsBoolType())
self.visit_expr(cond, cond_tc)
self.visit(branch_true)
......@@ -226,35 +320,49 @@ class Typifier:
if ctr.symbol.dtype is None:
ctr.symbol.apply_dtype(self._ctx.index_dtype)
tc = TypeContext(ctr.symbol.dtype)
self.visit_expr(start, tc)
self.visit_expr(stop, tc)
self.visit_expr(step, tc)
tc_index = TypeContext(ctr.symbol.dtype)
self.visit_expr(start, tc_index)
self.visit_expr(stop, tc_index)
self.visit_expr(step, tc_index)
self.visit(body)
case PsComment():
pass
case _:
raise NotImplementedError(f"Can't typify {node}")
def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None:
"""Recursive processing of expression nodes"""
"""Recursive processing of expression nodes.
This method opens, expands, and closes typing contexts according to the respective expression's
typing rules. It may add or check restrictions only when opening or closing a type context.
The actual type inference and checking during context expansion are performed by the methods
of `TypeContext`. ``visit_expr`` tells the typing context how to handle an expression by calling
either ``apply_dtype`` or ``infer_dtype``.
"""
match expr:
case PsSymbolExpr(_):
if expr.dtype is None:
tc.apply_dtype(expr, self._ctx.default_dtype)
else:
tc.apply_dtype(expr, expr.dtype)
if expr.symbol.dtype is None:
expr.symbol.dtype = self._ctx.default_dtype
case PsConstantExpr(_):
tc.infer_dtype(expr)
tc.apply_dtype(expr.symbol.dtype, expr)
case PsConstantExpr(c):
if c.dtype is not None:
tc.apply_dtype(c.dtype, expr)
else:
tc.infer_dtype(expr)
case PsArrayAccess(bptr, idx):
tc.apply_dtype(expr, bptr.array.element_type)
tc.apply_dtype(bptr.array.element_type, expr)
index_tc = TypeContext()
self.visit_expr(idx, index_tc)
if index_tc.target_type is None:
index_tc.apply_dtype(idx, self._ctx.index_dtype)
index_tc.apply_dtype(self._ctx.index_dtype, idx)
elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError(
f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
......@@ -269,12 +377,12 @@ class Typifier:
"Type of subscript base is not subscriptable."
)
tc.apply_dtype(expr, arr_tc.target_type.base_type)
tc.apply_dtype(arr_tc.target_type.base_type, expr)
index_tc = TypeContext()
self.visit_expr(idx, index_tc)
if index_tc.target_type is None:
index_tc.apply_dtype(idx, self._ctx.index_dtype)
index_tc.apply_dtype(self._ctx.index_dtype, idx)
elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError(
f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}"
......@@ -289,7 +397,7 @@ class Typifier:
"Type of argument to a Deref is not dereferencable"
)
tc.apply_dtype(expr, ptr_tc.target_type.base_type)
tc.apply_dtype(ptr_tc.target_type.base_type, expr)
case PsAddressOf(arg):
arg_tc = TypeContext()
......@@ -301,10 +409,11 @@ class Typifier:
)
ptr_type = PsPointerType(arg_tc.target_type, True)
tc.apply_dtype(expr, ptr_type)
tc.apply_dtype(ptr_type, expr)
case PsLookup(aggr, member_name):
aggr_tc = TypeContext(None)
# Members of a struct type inherit the struct type's `const` qualifier
aggr_tc = TypeContext(None, require_nonconst=tc.require_nonconst)
self.visit_expr(aggr, aggr_tc)
aggr_type = aggr_tc.target_type
......@@ -319,13 +428,39 @@ class Typifier:
f"Aggregate of type {aggr_type} does not have a member {member}."
)
tc.apply_dtype(expr, member.dtype)
member_type = member.dtype
if aggr_type.const:
member_type = constify(member_type)
tc.apply_dtype(member_type, expr)
case PsRel(op1, op2):
args_tc = TypeContext()
self.visit_expr(op1, args_tc)
self.visit_expr(op2, args_tc)
if args_tc.target_type is None:
raise TypificationError(
f"Unable to determine type of arguments to relation: {expr}"
)
if not isinstance(args_tc.target_type, PsNumericType):
raise TypificationError(
f"Invalid type in arguments to relation\n"
f" Expression: {expr}\n"
f" Arguments Type: {args_tc.target_type}"
)
tc.apply_dtype(PsBoolType(), expr)
case PsBinOp(op1, op2):
self.visit_expr(op1, tc)
self.visit_expr(op2, tc)
tc.infer_dtype(expr)
case PsNeg(op) | PsNot(op):
self.visit_expr(op, tc)
tc.infer_dtype(expr)
case PsCall(function, args):
match function:
case PsMathFunction():
......@@ -358,14 +493,14 @@ class Typifier:
f"{len(items)} items as {tc.target_type}"
)
else:
items_tc.apply_dtype(None, tc.target_type.base_type)
items_tc.apply_dtype(tc.target_type.base_type)
else:
arr_type = PsArrayType(items_tc.target_type, len(items))
tc.apply_dtype(expr, arr_type)
tc.apply_dtype(arr_type, expr)
case PsCast(dtype, arg):
self.visit_expr(arg, TypeContext())
tc.apply_dtype(expr, dtype)
tc.apply_dtype(dtype, expr)
case _:
raise NotImplementedError(f"Can't typify {expr}")
......@@ -7,6 +7,7 @@ from ...types import PsType, PsIeeeFloatType
from .platform import Platform
from ..exceptions import MaterializationError
from ..kernelcreation import AstFactory
from ..kernelcreation.iteration_space import (
IterationSpace,
FullIterationSpace,
......@@ -76,28 +77,17 @@ class GenericCpu(Platform):
def _create_domain_loops(
self, body: PsBlock, ispace: FullIterationSpace
) -> PsBlock:
dimensions = ispace.dimensions
factory = AstFactory(self._ctx)
# Determine loop order by permuting dimensions
archetype_field = ispace.archetype_field
if archetype_field is not None:
loop_order = archetype_field.layout
dimensions = [dimensions[coordinate] for coordinate in loop_order]
outer_block = body
for dimension in dimensions[::-1]:
loop = PsLoop(
PsSymbolExpr(dimension.counter),
dimension.start,
dimension.stop,
dimension.step,
outer_block,
)
outer_block = PsBlock([loop])
else:
loop_order = None
return outer_block
loops = factory.loops_from_ispace(ispace, body, loop_order)
return PsBlock([loops])
def _create_sparse_loop(self, body: PsBlock, ispace: SparseIterationSpace):
mappings = [
......
......@@ -14,7 +14,7 @@ from ..ast.expressions import (
PsSymbolExpr,
PsAdd,
)
from ..ast.logical_expressions import PsLt, PsAnd
from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType
from ..symbols import PsSymbol
......@@ -56,8 +56,10 @@ class GenericGpu(Platform):
]
return indices[:dim]
def select_function(self, math_function: PsMathFunction, dtype: PsType) -> CFunction:
def select_function(
self, math_function: PsMathFunction, dtype: PsType
) -> CFunction:
raise NotImplementedError()
# Internals
......
......@@ -42,7 +42,9 @@ class PsSymbol:
def get_dtype(self) -> PsType:
if self._dtype is None:
raise PsInternalCompilerError("Symbol had no type assigned yet")
raise PsInternalCompilerError(
f"Symbol {self.name} had no type assigned yet"
)
return self._dtype
def __str__(self) -> str:
......
"""
This module contains various transformation and optimization passes that can be
executed on the backend AST.
Canonical Form
==============
Many transformations in this module require that their input AST is in *canonical form*.
This means that:
- Each symbol, constant, and expression node is annotated with a data type;
- Each symbol has at most one declaration;
- Each symbol that is never written to apart from its declaration has a ``const`` type; and
- Each symbol whose type is *not* ``const`` has at least one non-declaring assignment.
The first requirement can be ensured by running the `Typifier` on each newly constructed subtree.
The other three requirements are ensured by the `CanonicalizeSymbols` pass,
which should be run first before applying any optimizing transformations.
All transformations in this module retain canonicality of the AST.
Canonicality allows transformations to forego various checks that would otherwise be necessary
to prove their legality.
Certain transformations, like the auto-vectorizer (TODO), state additional requirements, e.g.
the absence of loop-carried dependencies.
Transformations
===============
Canonicalization
----------------
.. autoclass:: CanonicalizeSymbols
:members: __call__
AST Cloning
-----------
.. autoclass:: CanonicalClone
:members: __call__
Simplifying Transformations
---------------------------
.. autoclass:: EliminateConstants
:members: __call__
.. autoclass:: EliminateBranches
:members: __call__
Code Motion
-----------
.. autoclass:: HoistLoopInvariantDeclarations
:members: __call__
Loop Reshaping Transformations
------------------------------
.. autoclass:: ReshapeLoops
:members:
Code Lowering and Materialization
---------------------------------
.. autoclass:: EraseAnonymousStructTypes
:members: __call__
.. autoclass:: SelectFunctions
:members: __call__
"""
from .canonicalize_symbols import CanonicalizeSymbols
from .canonical_clone import CanonicalClone
from .eliminate_constants import EliminateConstants
from .eliminate_branches import EliminateBranches
from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations
from .reshape_loops import ReshapeLoops
from .erase_anonymous_structs import EraseAnonymousStructTypes
from .select_functions import SelectFunctions
from .select_intrinsics import MaterializeVectorIntrinsics
__all__ = [
"CanonicalizeSymbols",
"CanonicalClone",
"EliminateConstants",
"EliminateBranches",
"HoistLoopInvariantDeclarations",
"ReshapeLoops",
"EraseAnonymousStructTypes",
"SelectFunctions",
"MaterializeVectorIntrinsics",
......