Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 3461 additions and 0 deletions
from dataclasses import dataclass
from typing import cast
from functools import reduce
import operator
from .structural import (
PsAssignment,
PsAstNode,
PsBlock,
PsEmptyLeafMixIn,
PsConditional,
PsDeclaration,
PsExpression,
PsLoop,
PsStatement,
)
from .expressions import (
PsAdd,
PsBufferAcc,
PsCall,
PsConstantExpr,
PsDiv,
PsIntDiv,
PsLiteralExpr,
PsMul,
PsNeg,
PsRem,
PsSub,
PsSymbolExpr,
PsTernary,
PsSubscript,
PsMemAcc,
)
from ..memory import PsSymbol
from ..exceptions import PsInternalCompilerError
from ...types import PsNumericType
from ...types.exception import PsTypeError
class UndefinedSymbolsCollector:
"""Collect undefined symbols.
This class implements an AST visitor that collects all symbols that have been used
in the AST without being defined prior to their usage.
"""
def __call__(self, node: PsAstNode) -> set[PsSymbol]:
"""Returns all symbols that occur in the given AST without being defined prior to their usage."""
return self.visit(node)
def visit(self, node: PsAstNode) -> set[PsSymbol]:
undefined_vars: set[PsSymbol] = set()
match node:
case PsExpression():
return self.visit_expr(node)
case PsStatement(expr):
return self.visit_expr(expr)
case PsAssignment(lhs, rhs):
undefined_vars = self(lhs) | self(rhs)
if isinstance(lhs, PsSymbolExpr):
undefined_vars.remove(lhs.symbol)
return undefined_vars
case PsBlock(statements):
for stmt in statements[::-1]:
undefined_vars -= self.declared_variables(stmt)
undefined_vars |= self(stmt)
return undefined_vars
case PsLoop(ctr, start, stop, step, body):
undefined_vars = self(start) | self(stop) | self(step) | self(body)
undefined_vars.discard(ctr.symbol)
return undefined_vars
case PsConditional(cond, branch_true, branch_false):
undefined_vars = self(cond) | self(branch_true)
if branch_false is not None:
undefined_vars |= self(branch_false)
return undefined_vars
case PsEmptyLeafMixIn():
return set()
case unknown:
raise PsInternalCompilerError(
f"Don't know how to collect undefined variables from {unknown}"
)
def visit_expr(self, expr: PsExpression) -> set[PsSymbol]:
match expr:
case PsSymbolExpr(symb):
return {symb}
case _:
return reduce(
set.union,
(self.visit_expr(cast(PsExpression, c)) for c in expr.children),
set(),
)
def declared_variables(self, node: PsAstNode) -> set[PsSymbol]:
"""Returns the set of variables declared by the given node which are visible in the enclosing scope."""
match node:
case PsDeclaration():
return {node.declared_symbol}
case (
PsAssignment()
| PsBlock()
| PsConditional()
| PsExpression()
| PsLoop()
| PsStatement()
| PsEmptyLeafMixIn()
):
return set()
case unknown:
raise PsInternalCompilerError(
f"Don't know how to collect declared variables from {unknown}"
)
def collect_undefined_symbols(node: PsAstNode) -> set[PsSymbol]:
    """Return the symbols used in ``node`` without a prior definition.

    Convenience wrapper that runs a fresh `UndefinedSymbolsCollector` on the node.
    """
    collector = UndefinedSymbolsCollector()
    return collector(node)
def collect_required_headers(node: PsAstNode) -> set[str]:
match node:
case PsSymbolExpr(symb):
return symb.get_dtype().required_headers
case PsConstantExpr(cs):
return cs.get_dtype().required_headers
case _:
return reduce(
set.union, (collect_required_headers(c) for c in node.children), set()
)
@dataclass
class OperationCounts:
    """Tally of the operations found in an AST.

    Two tallies can be added field by field, and a tally can be scaled by
    an integer factor written on the left (``n * counts``).
    """

    float_adds: int = 0
    float_muls: int = 0
    float_divs: int = 0
    int_adds: int = 0
    int_muls: int = 0
    int_divs: int = 0
    calls: int = 0
    branches: int = 0
    loops_with_dynamic_bounds: int = 0

    def __add__(self, other):
        """Return the field-wise sum of two tallies."""
        if isinstance(other, OperationCounts):
            # Positional arguments follow the field declaration order above
            return OperationCounts(
                self.float_adds + other.float_adds,
                self.float_muls + other.float_muls,
                self.float_divs + other.float_divs,
                self.int_adds + other.int_adds,
                self.int_muls + other.int_muls,
                self.int_divs + other.int_divs,
                self.calls + other.calls,
                self.branches + other.branches,
                self.loops_with_dynamic_bounds + other.loops_with_dynamic_bounds,
            )
        return NotImplemented

    def __rmul__(self, other):
        """Return this tally with every field multiplied by the integer ``other``."""
        if isinstance(other, int):
            factor = other
            # Positional arguments follow the field declaration order above
            return OperationCounts(
                factor * self.float_adds,
                factor * self.float_muls,
                factor * self.float_divs,
                factor * self.int_adds,
                factor * self.int_muls,
                factor * self.int_divs,
                factor * self.calls,
                factor * self.branches,
                factor * self.loops_with_dynamic_bounds,
            )
        return NotImplemented
class OperationCounter:
"""Counts the number of operations in an AST.
Assumes that the AST is typed. It is recommended that constant folding is
applied prior to this pass.
The counted operations are:
- Additions, multiplications and divisions of floating and integer type.
The counts of either type are reported separately and operations on
other types are ignored.
- Function calls.
- Branches.
Includes `PsConditional` and `PsTernary`. The operations in all branches
are summed up (i.e. the result is an overestimation).
- Loops with an unknown number of iterations.
The operations in the loop header and body are counted exactly once,
i.e. it is assumed that there is one loop iteration.
If the start, stop and step of the loop are `PsConstantExpr`, then any
operation within the body is multiplied by the number of iterations.
"""
def __call__(self, node: PsAstNode) -> OperationCounts:
"""Counts the number of operations in the given AST."""
return self.visit(node)
def visit(self, node: PsAstNode) -> OperationCounts:
match node:
case PsExpression():
return self.visit_expr(node)
case PsStatement(expr):
return self.visit_expr(expr)
case PsAssignment(lhs, rhs):
return self.visit_expr(lhs) + self.visit_expr(rhs)
case PsBlock(statements):
return reduce(
operator.add, (self.visit(s) for s in statements), OperationCounts()
)
case PsLoop(_, start, stop, step, body):
if (
isinstance(start, PsConstantExpr)
and isinstance(stop, PsConstantExpr)
and isinstance(step, PsConstantExpr)
):
val_start = start.constant.value
val_stop = stop.constant.value
val_step = step.constant.value
if (val_stop - val_start) % val_step == 0:
iteration_count = max(0, int((val_stop - val_start) / val_step))
else:
iteration_count = max(
0, int((val_stop - val_start) / val_step) + 1
)
return self.visit_expr(start) + iteration_count * (
OperationCounts(int_adds=1) # loop counter increment
+ self.visit_expr(stop)
+ self.visit_expr(step)
+ self.visit(body)
)
else:
return (
OperationCounts(loops_with_dynamic_bounds=1)
+ self.visit_expr(start)
+ self.visit_expr(stop)
+ self.visit_expr(step)
+ self.visit(body)
)
case PsConditional(cond, branch_true, branch_false):
op_counts = (
OperationCounts(branches=1)
+ self.visit(cond)
+ self.visit(branch_true)
)
if branch_false is not None:
op_counts += self.visit(branch_false)
return op_counts
case PsEmptyLeafMixIn():
return OperationCounts()
case unknown:
raise PsInternalCompilerError(f"Can't count operations in {unknown}")
def visit_expr(self, expr: PsExpression) -> OperationCounts:
match expr:
case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_):
return OperationCounts()
case PsBufferAcc(_, indices) | PsSubscript(_, indices):
return reduce(operator.add, (self.visit_expr(idx) for idx in indices))
case PsMemAcc(_, offset):
return self.visit_expr(offset)
case PsCall(_, args):
return OperationCounts(calls=1) + reduce(
operator.add, (self.visit(a) for a in args), OperationCounts()
)
case PsTernary(cond, then, els):
return (
OperationCounts(branches=1)
+ self.visit_expr(cond)
+ self.visit_expr(then)
+ self.visit_expr(els)
)
case PsNeg(arg):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_muls += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_muls += 1
return op_counts
case PsAdd(arg1, arg2) | PsSub(arg1, arg2):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_adds += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_adds += 1
return op_counts
case PsMul(arg1, arg2):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_muls += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_muls += 1
return op_counts
case PsDiv(arg1, arg2) | PsIntDiv(arg1, arg2) | PsRem(arg1, arg2):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_divs += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_divs += 1
return op_counts
case _:
return reduce(
operator.add,
(self.visit_expr(cast(PsExpression, c)) for c in expr.children),
OperationCounts(),
)
from __future__ import annotations
from typing import Sequence
from abc import ABC, abstractmethod
class PsAstNode(ABC):
    """Common base of all nodes in the pystencils AST.

    Provides a uniform interface for inspecting and updating the branching
    structure of the AST. Every subclass must implement `get_children` and
    `set_child`, and is responsible for the type checks needed if it places
    restrictions on the types of its children.
    """

    @property
    def children(self) -> Sequence[PsAstNode]:
        """The child nodes, as returned by `get_children`."""
        return self.get_children()

    @children.setter
    def children(self, cs: Sequence[PsAstNode]):
        # Assign position by position so each subclass's `set_child` checks apply
        for position, child in enumerate(cs):
            self.set_child(position, child)

    @abstractmethod
    def get_children(self) -> tuple[PsAstNode, ...]:
        """Retrieve the child nodes of this AST node.

        This operation must be implemented by subclasses.
        """

    @abstractmethod
    def set_child(self, idx: int, c: PsAstNode):
        """Update the child node at position ``idx``.

        This operation must be implemented by subclasses.
        """

    @abstractmethod
    def clone(self) -> PsAstNode:
        """Perform a deep copy of the AST."""

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Check two ASTs for structural equality.

        The default compares the exact node type and, pairwise, the children.
        If an AST node has additional internal state, it MUST override this method.
        """
        if type(self) is not type(other):
            return False

        mine = self.children
        theirs = other.children
        if len(mine) != len(theirs):
            return False

        return all(left.structurally_equal(right) for left, right in zip(mine, theirs))

    def __str__(self) -> str:
        # Imported here rather than at module level
        from ..emission import emit_ir

        return emit_ir(self)
class PsLeafMixIn(ABC):
    """Mix-in for AST leaves.

    Supplies the child accessors of a node without children and requires the
    concrete leaf to implement `structurally_equal` itself.
    """

    def get_children(self) -> tuple[PsAstNode, ...]:
        """Return the empty tuple: a leaf has no children."""
        return tuple()

    def set_child(self, idx: int, c: PsAstNode):
        """Always raise `IndexError`, since no child position exists."""
        raise IndexError("Child index out of bounds: Leaf nodes have no children.")

    @abstractmethod
    def structurally_equal(self, other: PsAstNode) -> bool:
        """Check structural equality; must be implemented by each leaf class."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Sequence, overload, Callable, Any, cast
import operator
import numpy as np
from numpy.typing import NDArray
from ..memory import PsSymbol, PsBuffer, BufferBasePtr
from ..constants import PsConstant
from ..literals import PsLiteral
from ..functions import PsFunction
from ...types import PsType
from .util import failing_cast
from ..exceptions import PsInternalCompilerError
from .astnode import PsAstNode, PsLeafMixIn
class PsExpression(PsAstNode, ABC):
"""Base class for all expressions.
**Types:** Each expression should be annotated with its type.
Upon construction, the `dtype <PsExpression.dtype>` property of most expression nodes is unset;
only constant expressions, symbol expressions, and array accesses immediately inherit their type from
their constant, symbol, or array, respectively.
The canonical way to add types to newly constructed expressions is through the `Typifier`.
It should be run at least once on any expression constructed by the backend.
The type annotations are used by various transformation passes to make decisions, e.g. in
function materialization and intrinsic selection.
.. attention::
The ``structurally_equal <PsAstNode.structurally_equal>`` check currently does not
take expression data types into account. This may change in the future.
"""
def __init__(self, dtype: PsType | None = None) -> None:
self._dtype = dtype
@property
def dtype(self) -> PsType | None:
"""Data type assigned to this expression"""
return self._dtype
@dtype.setter
def dtype(self, dt: PsType):
self._dtype = dt
def get_dtype(self) -> PsType:
"""Retrieve the data type assigned to this expression.
Raises:
PsInternalCompilerError: If this expression has no data type assigned
"""
if self._dtype is None:
raise PsInternalCompilerError(f"No data type set on expression {self}.")
return self._dtype
def __add__(self, other: PsExpression) -> PsAdd:
if not isinstance(other, PsExpression):
return NotImplemented
return PsAdd(self, other)
def __sub__(self, other: PsExpression) -> PsSub:
if not isinstance(other, PsExpression):
return NotImplemented
return PsSub(self, other)
def __mul__(self, other: PsExpression) -> PsMul:
if not isinstance(other, PsExpression):
return NotImplemented
return PsMul(self, other)
def __truediv__(self, other: PsExpression) -> PsDiv:
if not isinstance(other, PsExpression):
return NotImplemented
return PsDiv(self, other)
def __neg__(self) -> PsNeg:
return PsNeg(self)
@overload
@staticmethod
def make(obj: PsSymbol) -> PsSymbolExpr:
pass
@overload
@staticmethod
def make(obj: PsConstant) -> PsConstantExpr:
pass
@overload
@staticmethod
def make(obj: PsLiteral) -> PsLiteralExpr:
pass
@staticmethod
def make(obj: PsSymbol | PsConstant | PsLiteral) -> PsExpression:
if isinstance(obj, PsSymbol):
return PsSymbolExpr(obj)
elif isinstance(obj, PsConstant):
return PsConstantExpr(obj)
elif isinstance(obj, PsLiteral):
return PsLiteralExpr(obj)
else:
raise ValueError(f"Cannot make expression out of {obj}")
def clone(self):
"""Clone this expression.
.. note::
Subclasses of `PsExpression` should not override this method,
but implement `_clone_expr` instead.
That implementation shall call `clone` on any of its subexpressions,
but does not need to fix the `dtype <PsExpression.dtype>` property.
The ``dtype`` is correctly applied by `PsExpression.clone` internally.
"""
cloned = self._clone_expr()
cloned._dtype = self.dtype
return cloned
@abstractmethod
def _clone_expr(self) -> PsExpression:
"""Implementation of expression cloning.
:meta public:
"""
pass
class PsLvalue(ABC):
    """Marker mix-in for expressions that may occur as an lvalue.

    An lvalue expression is one that represents a memory location.
    """
class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression):
    """A single symbol as an expression.

    The expression's data type is taken from the symbol at construction.
    """

    __match_args__ = ("symbol",)

    def __init__(self, symbol: PsSymbol):
        super().__init__(symbol.dtype)
        self._symbol = symbol

    @property
    def symbol(self) -> PsSymbol:
        """The symbol wrapped by this expression."""
        return self._symbol

    @symbol.setter
    def symbol(self, symbol: PsSymbol):
        self._symbol = symbol

    def _clone_expr(self) -> PsSymbolExpr:
        # The symbol itself is shared between the original and the clone
        return PsSymbolExpr(self._symbol)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Equal if ``other`` is a symbol expression wrapping an equal symbol."""
        return isinstance(other, PsSymbolExpr) and self._symbol == other._symbol

    def __repr__(self) -> str:
        return f"Symbol({self._symbol!r})"
class PsConstantExpr(PsLeafMixIn, PsExpression):
    """A single constant as an expression.

    The expression's data type is taken from the constant at construction.
    """

    __match_args__ = ("constant",)

    def __init__(self, constant: PsConstant):
        super().__init__(constant.dtype)
        self._constant = constant

    @property
    def constant(self) -> PsConstant:
        """The constant wrapped by this expression."""
        return self._constant

    @constant.setter
    def constant(self, c: PsConstant):
        self._constant = c

    def _clone_expr(self) -> PsConstantExpr:
        # The constant itself is shared between the original and the clone
        return PsConstantExpr(self._constant)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Equal if ``other`` is a constant expression wrapping an equal constant."""
        return isinstance(other, PsConstantExpr) and self._constant == other._constant

    def __repr__(self) -> str:
        return f"PsConstantExpr({self._constant!r})"
class PsLiteralExpr(PsLeafMixIn, PsExpression):
    """A single literal as an expression.

    The expression's data type is taken from the literal at construction.
    """

    __match_args__ = ("literal",)

    def __init__(self, literal: PsLiteral):
        super().__init__(literal.dtype)
        self._literal = literal

    @property
    def literal(self) -> PsLiteral:
        """The literal wrapped by this expression."""
        return self._literal

    @literal.setter
    def literal(self, lit: PsLiteral):
        self._literal = lit

    def _clone_expr(self) -> PsLiteralExpr:
        # The literal itself is shared between the original and the clone
        return PsLiteralExpr(self._literal)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Equal if ``other`` is a literal expression wrapping an equal literal."""
        return isinstance(other, PsLiteralExpr) and self._literal == other._literal

    def __repr__(self) -> str:
        return f"PsLiteralExpr({self._literal!r})"
class PsBufferAcc(PsLvalue, PsExpression):
    """Access into a `PsBuffer`.

    The buffer is found through the `BufferBasePtr` property of the base
    pointer symbol. Children are the base pointer followed by the index
    expressions.
    """

    __match_args__ = ("base_pointer", "index")

    def __init__(self, base_ptr: PsSymbol, index: Sequence[PsExpression]):
        super().__init__()

        base_ptr_prop = cast(
            BufferBasePtr, base_ptr.get_properties(BufferBasePtr).pop()
        )
        accessed_buffer = base_ptr_prop.buffer

        # One index expression is required per buffer dimension
        if len(index) != accessed_buffer.dim:
            raise ValueError("Number of index expressions must equal buffer shape.")

        self._base_ptr = PsExpression.make(base_ptr)
        self._index = list(index)
        # The access has the type of a single buffer element
        self._dtype = accessed_buffer.element_type

    @property
    def base_pointer(self) -> PsSymbolExpr:
        """Symbol expression holding the buffer's base pointer."""
        return self._base_ptr

    @base_pointer.setter
    def base_pointer(self, expr: PsSymbolExpr):
        replacement_prop = cast(
            BufferBasePtr, expr.symbol.get_properties(BufferBasePtr).pop()
        )
        if replacement_prop.buffer != self.buffer:
            raise ValueError(
                "Cannot replace a buffer access's base pointer with one belonging to a different buffer."
            )
        self._base_ptr = expr

    @property
    def buffer(self) -> PsBuffer:
        """The buffer the current base pointer belongs to."""
        current_prop = cast(
            BufferBasePtr, self._base_ptr.symbol.get_properties(BufferBasePtr).pop()
        )
        return current_prop.buffer

    @property
    def index(self) -> list[PsExpression]:
        """The index expressions of this access."""
        return self._index

    def get_children(self) -> tuple[PsAstNode, ...]:
        return (self._base_ptr, *self._index)

    def set_child(self, idx: int, c: PsAstNode):
        # Indexing a range normalizes negative positions and rejects invalid ones
        idx = range(len(self._index) + 1)[idx]
        if idx != 0:
            self._index[idx - 1] = failing_cast(PsExpression, c)
        else:
            self.base_pointer = failing_cast(PsSymbolExpr, c)

    def _clone_expr(self) -> PsBufferAcc:
        cloned_index = [i.clone() for i in self._index]
        return PsBufferAcc(self._base_ptr.symbol, cloned_index)

    def __repr__(self) -> str:
        return f"PsBufferAcc({self._base_ptr!r}, {self._index!r})"
class PsSubscript(PsLvalue, PsExpression):
"""N-dimensional subscript into an array."""
__match_args__ = ("array", "index")
def __init__(self, arr: PsExpression, index: Sequence[PsExpression]):
super().__init__()
self._arr = arr
if not index:
raise ValueError("Subscript index cannot be empty.")
self._index = list(index)
@property
def array(self) -> PsExpression:
return self._arr
@array.setter
def array(self, expr: PsExpression):
self._arr = expr
@property
def index(self) -> list[PsExpression]:
return self._index
@index.setter
def index(self, idx: Sequence[PsExpression]):
self._index = list(idx)
def _clone_expr(self) -> PsSubscript:
return PsSubscript(self._arr.clone(), [i.clone() for i in self._index])
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._arr,) + tuple(self._index)
def set_child(self, idx: int, c: PsAstNode):
idx = range(len(self._index) + 1)[idx]
match idx:
case 0:
self.array = failing_cast(PsExpression, c)
case _:
self.index[idx - 1] = failing_cast(PsExpression, c)
def __repr__(self) -> str:
idx = ", ".join(repr(i) for i in self._index)
return f"PsSubscript({repr(self._arr)}, {repr(idx)})"
class PsMemAcc(PsLvalue, PsExpression):
    """Pointer-based memory access with type-dependent offset.

    Children are the pointer expression and the offset expression, in that order.
    """

    __match_args__ = ("pointer", "offset")

    def __init__(self, ptr: PsExpression, offset: PsExpression):
        super().__init__()
        self._ptr = ptr
        self._offset = offset

    @property
    def pointer(self) -> PsExpression:
        """The pointer expression being accessed."""
        return self._ptr

    @pointer.setter
    def pointer(self, expr: PsExpression):
        self._ptr = expr

    @property
    def offset(self) -> PsExpression:
        """The offset expression applied to the pointer."""
        return self._offset

    @offset.setter
    def offset(self, expr: PsExpression):
        self._offset = expr

    def _clone_expr(self) -> PsMemAcc:
        cloned_ptr = self._ptr.clone()
        cloned_offset = self._offset.clone()
        return PsMemAcc(cloned_ptr, cloned_offset)

    def get_children(self) -> tuple[PsAstNode, ...]:
        return (self._ptr, self._offset)

    def set_child(self, idx: int, c: PsAstNode):
        # Indexing the list normalizes negative positions and rejects invalid ones
        slot = [0, 1][idx]
        if slot == 0:
            self.pointer = failing_cast(PsExpression, c)
        else:
            self.offset = failing_cast(PsExpression, c)

    def __repr__(self) -> str:
        return f"PsMemAcc({self._ptr!r}, {self._offset!r})"
class PsLookup(PsExpression, PsLvalue):
    """Lookup of a named member in an aggregate expression.

    The aggregate expression is the only child; the member name is additional
    state of the node.
    """

    __match_args__ = ("aggregate", "member_name")

    def __init__(self, aggregate: PsExpression, member_name: str) -> None:
        super().__init__()
        self._aggregate = aggregate
        self._member_name = member_name

    @property
    def aggregate(self) -> PsExpression:
        """The expression whose member is looked up."""
        return self._aggregate

    @aggregate.setter
    def aggregate(self, aggr: PsExpression):
        self._aggregate = aggr

    @property
    def member_name(self) -> str:
        """Name of the member being looked up."""
        return self._member_name

    @member_name.setter
    def member_name(self, name: str):
        # Must write the attribute the getter reads
        self._member_name = name

    def _clone_expr(self) -> PsLookup:
        return PsLookup(self._aggregate.clone(), self._member_name)

    def get_children(self) -> tuple[PsAstNode, ...]:
        return (self._aggregate,)

    def set_child(self, idx: int, c: PsAstNode):
        # Indexing the list rejects every position except the single valid one
        idx = [0][idx]
        self._aggregate = failing_cast(PsExpression, c)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Equal if ``other`` is a lookup of the same member in an equal aggregate."""
        if not isinstance(other, PsLookup):
            return False
        # The member name is state beyond the children, so compare it explicitly
        return (
            super().structurally_equal(other)
            and self._member_name == other._member_name
        )

    def __repr__(self) -> str:
        return f"PsLookup({repr(self._aggregate)}, {repr(self._member_name)})"
class PsCall(PsExpression):
    """Application of a `PsFunction` to a sequence of argument expressions.

    The number of arguments always has to match the function's ``arg_count``.
    Children are the argument expressions.
    """

    __match_args__ = ("function", "args")

    def __init__(self, function: PsFunction, args: Sequence[PsExpression]) -> None:
        """Create a call of ``function`` with ``args``.

        Raises:
            ValueError: If the number of arguments differs from ``function.arg_count``.
        """
        num_args = len(args)
        if num_args != function.arg_count:
            raise ValueError(
                f"Argument count mismatch: Cannot apply function {function} to {num_args} arguments."
            )

        super().__init__()

        self._function = function
        self._args = list(args)

    @property
    def function(self) -> PsFunction:
        """The function being called."""
        return self._function

    @function.setter
    def function(self, func: PsFunction):
        # The replacement must accept the arguments already stored
        if func.arg_count != self._function.arg_count:
            raise ValueError(
                "Current and replacement function must have the same number of parameters."
            )
        self._function = func

    @property
    def args(self) -> tuple[PsExpression, ...]:
        """The argument expressions, as a tuple."""
        return tuple(self._args)

    @args.setter
    def args(self, exprs: Sequence[PsExpression]):
        num_args = len(exprs)
        if num_args != self._function.arg_count:
            raise ValueError(
                f"Argument count mismatch: Cannot apply function {self._function} to {num_args} arguments."
            )
        self._args = list(exprs)

    def _clone_expr(self) -> PsCall:
        cloned_args = [arg.clone() for arg in self._args]
        return PsCall(self._function, cloned_args)

    def get_children(self) -> tuple[PsAstNode, ...]:
        return self.args

    def set_child(self, idx: int, c: PsAstNode):
        self._args[idx] = failing_cast(PsExpression, c)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Equal if ``other`` is a call of an equal function with equal arguments."""
        return (
            isinstance(other, PsCall)
            and super().structurally_equal(other)
            and self._function == other._function
        )

    def __repr__(self):
        arg_list = ", ".join(map(repr, self._args))
        return f"PsCall({self._function!r}, ({arg_list}))"
class PsTernary(PsExpression):
    """Ternary operator.

    Children are the condition, the then-case and the else-case, in that order.
    """

    __match_args__ = ("condition", "case_then", "case_else")

    def __init__(
        self, cond: PsExpression, then: PsExpression, els: PsExpression
    ) -> None:
        super().__init__()
        self._cond = cond
        self._then = then
        self._else = els

    @property
    def condition(self) -> PsExpression:
        """The condition expression."""
        return self._cond

    @property
    def case_then(self) -> PsExpression:
        """The expression of the then-case."""
        return self._then

    @property
    def case_else(self) -> PsExpression:
        """The expression of the else-case."""
        return self._else

    def _clone_expr(self) -> PsExpression:
        cond, then, els = (part.clone() for part in self.get_children())
        return PsTernary(cond, then, els)

    def get_children(self) -> tuple[PsExpression, ...]:
        return (self._cond, self._then, self._else)

    def set_child(self, idx: int, c: PsAstNode):
        # Indexing a range normalizes negative positions and rejects invalid ones
        slot = range(3)[idx]
        if slot == 0:
            self._cond = failing_cast(PsExpression, c)
        elif slot == 1:
            self._then = failing_cast(PsExpression, c)
        else:
            self._else = failing_cast(PsExpression, c)

    def __repr__(self) -> str:
        return f"PsTernary({self._cond!r}, {self._then!r}, {self._else!r})"
class PsNumericOpTrait:
    """Marker trait: the operation is valid only on numerical types."""
class PsIntOpTrait:
    """Marker trait: the operation is valid only on integer types."""
class PsBoolOpTrait:
    """Marker trait for boolean operations."""
class PsUnOp(PsExpression):
    """Base class of unary operations; the operand is the only child."""

    __match_args__ = ("operand",)

    def __init__(self, operand: PsExpression):
        super().__init__()
        self._operand = operand

    @property
    def operand(self) -> PsExpression:
        """The expression this operation is applied to."""
        return self._operand

    @operand.setter
    def operand(self, expr: PsExpression):
        self._operand = expr

    def _clone_expr(self) -> PsUnOp:
        # Re-create the concrete subclass from a clone of the operand
        node_type = type(self)
        return node_type(self._operand.clone())

    def get_children(self) -> tuple[PsAstNode, ...]:
        return (self._operand,)

    def set_child(self, idx: int, c: PsAstNode):
        # Indexing the list rejects every position except the single valid one
        idx = [0][idx]
        self._operand = failing_cast(PsExpression, c)

    @property
    def python_operator(self) -> None | Callable[[Any], Any]:
        """Python callable implementing this operation; ``None`` in this base class."""
        return None

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._operand!r})"
class PsNeg(PsUnOp, PsNumericOpTrait):
    """Unary negation of a numerical operand."""

    @property
    def python_operator(self):
        """Python callable implementing this operation: `operator.neg`."""
        return operator.neg
class PsAddressOf(PsUnOp):
    """Take the address of a memory location.

    .. DANGER::
        Taking the address of a memory location owned by a symbol or field array
        introduces an alias to that memory location.
        As pystencils assumes its symbols and fields to never be aliased, this can
        subtly change the semantics of a kernel.
        Use the address-of operator with utmost care.
    """
class PsCast(PsUnOp):
    """Cast of the operand to a target type.

    The target type is state beyond the children and takes part in the
    structural equality check.
    """

    __match_args__ = ("target_type", "operand")

    def __init__(self, target_type: PsType, operand: PsExpression):
        super().__init__(operand)
        self._target_type = target_type

    @property
    def target_type(self) -> PsType:
        """The type the operand is cast to."""
        return self._target_type

    @target_type.setter
    def target_type(self, dtype: PsType):
        self._target_type = dtype

    def _clone_expr(self) -> PsUnOp:
        # Overridden because the constructor takes the target type as well
        cloned_operand = self._operand.clone()
        return PsCast(self._target_type, cloned_operand)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Equal if ``other`` is a cast of an equal operand to an equal target type."""
        return (
            isinstance(other, PsCast)
            and super().structurally_equal(other)
            and self._target_type == other._target_type
        )
class PsBinOp(PsExpression):
    """Base class of binary operations; the two operands are the children."""

    __match_args__ = ("operand1", "operand2")

    def __init__(self, op1: PsExpression, op2: PsExpression):
        super().__init__()
        self._op1 = op1
        self._op2 = op2

    @property
    def operand1(self) -> PsExpression:
        """The first (left) operand."""
        return self._op1

    @operand1.setter
    def operand1(self, expr: PsExpression):
        self._op1 = expr

    @property
    def operand2(self) -> PsExpression:
        """The second (right) operand."""
        return self._op2

    @operand2.setter
    def operand2(self, expr: PsExpression):
        self._op2 = expr

    def _clone_expr(self) -> PsBinOp:
        # Re-create the concrete subclass from clones of both operands
        node_type = type(self)
        return node_type(self._op1.clone(), self._op2.clone())

    def get_children(self) -> tuple[PsAstNode, ...]:
        return (self._op1, self._op2)

    def set_child(self, idx: int, c: PsAstNode):
        # Indexing the list normalizes negative positions and rejects invalid ones
        slot = [0, 1][idx]
        if slot == 0:
            self._op1 = failing_cast(PsExpression, c)
        else:
            self._op2 = failing_cast(PsExpression, c)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._op1!r}, {self._op2!r})"

    @property
    def python_operator(self) -> None | Callable[[Any, Any], Any]:
        """Python callable implementing this operation; ``None`` in this base class."""
        return None
class PsAdd(PsBinOp, PsNumericOpTrait):
    """Addition of two numerical operands."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this operation: `operator.add`."""
        return operator.add
class PsSub(PsBinOp, PsNumericOpTrait):
    """Subtraction of two numerical operands."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this operation: `operator.sub`."""
        return operator.sub
class PsMul(PsBinOp, PsNumericOpTrait):
    """Multiplication of two numerical operands."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this operation: `operator.mul`."""
        return operator.mul
class PsDiv(PsBinOp, PsNumericOpTrait):
    """Division of two numerical operands."""

    # No `python_operator` is provided here: it cannot be decided
    # unambiguously between integer division and true division.
class PsIntDiv(PsBinOp, PsIntOpTrait):
    """C-like integer division (round to zero)."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any]:
        """Python callable implementing this operation: ``c_intdiv`` from the utils package."""
        # Imported on access rather than at module level
        from ...utils import c_intdiv

        return c_intdiv
class PsRem(PsBinOp, PsIntOpTrait):
    """C-style integer division remainder."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any]:
        """Python callable implementing this operation: ``c_rem`` from the utils package."""
        # Imported on access rather than at module level
        from ...utils import c_rem

        return c_rem
class PsLeftShift(PsBinOp, PsIntOpTrait):
    """Left shift of an integer operand."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this operation: `operator.lshift`."""
        return operator.lshift
class PsRightShift(PsBinOp, PsIntOpTrait):
    """Right shift of an integer operand."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this operation: `operator.rshift`."""
        return operator.rshift
class PsBitwiseAnd(PsBinOp, PsIntOpTrait):
    """Bitwise AND of two integer operands."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this operation: `operator.and_`."""
        return operator.and_
class PsBitwiseXor(PsBinOp, PsIntOpTrait):
    """Bitwise XOR of two integer operands."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this operation: `operator.xor`."""
        return operator.xor
class PsBitwiseOr(PsBinOp, PsIntOpTrait):
    """Bitwise OR of two integer operands."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this operation: `operator.or_`."""
        return operator.or_
class PsAnd(PsBinOp, PsBoolOpTrait):
    """Logical AND of two boolean operands."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this operation: `numpy.logical_and`."""
        return np.logical_and
class PsOr(PsBinOp, PsBoolOpTrait):
    """Logical OR of two boolean operands."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this operation: `numpy.logical_or`."""
        return np.logical_or
class PsNot(PsUnOp, PsBoolOpTrait):
    """Logical negation of a boolean operand."""

    @property
    def python_operator(self) -> Callable[[Any], Any] | None:
        """Python callable implementing this operation: `numpy.logical_not`."""
        return np.logical_not
class PsRel(PsBinOp):
    """Common base class of the binary relational operators."""
class PsEq(PsRel):
    """Relation: operands are equal."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this relation: `operator.eq`."""
        return operator.eq
class PsNe(PsRel):
    """Relation: operands are not equal."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this relation: `operator.ne`."""
        return operator.ne
class PsGe(PsRel):
    """Relation: first operand is greater than or equal to the second."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this relation: `operator.ge`."""
        return operator.ge
class PsLe(PsRel):
    """Relation: first operand is less than or equal to the second."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this relation: `operator.le`."""
        return operator.le
class PsGt(PsRel):
    """Relation: first operand is strictly greater than the second."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this relation: `operator.gt`."""
        return operator.gt
class PsLt(PsRel):
    """Relation: first operand is strictly less than the second."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable implementing this relation: `operator.lt`."""
        return operator.lt
class PsArrayInitList(PsExpression):
    """N-dimensional array initialization matrix.

    The items are stored in a NumPy object array; the children of the node
    are the items in the flat iteration order of that array.
    """

    __match_args__ = ("items",)

    def __init__(
        self,
        items: Sequence[PsExpression | Sequence[PsExpression | Sequence[PsExpression]]],
    ):
        super().__init__()
        self._items = np.array(items, dtype=np.object_)

    @property
    def items_grid(self) -> NDArray[np.object_]:
        """The items in their N-dimensional arrangement."""
        return self._items

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the item grid."""
        return self._items.shape

    @property
    def items(self) -> tuple[PsExpression, ...]:
        """All items as a flat tuple."""
        return tuple(self._items.flat)  # type: ignore

    def get_children(self) -> tuple[PsAstNode, ...]:
        return tuple(self._items.flat)  # type: ignore

    def set_child(self, idx: int, c: PsAstNode):
        # `idx` addresses the flattened item grid
        self._items.flat[idx] = failing_cast(PsExpression, c)

    def _clone_expr(self) -> PsExpression:
        # Clone the flat children, then restore the original grid shape
        flat_clones = np.array([expr.clone() for expr in self.children])
        regridded = flat_clones.reshape(self._items.shape)  # type: ignore
        return PsArrayInitList(regridded)

    def __repr__(self) -> str:
        return f"PsArrayInitList({self._items!r})"
def evaluate_expression(expr: PsExpression, valuation: dict[str, Any]) -> Any:
"""Evaluate a pystencils backend expression tree with values assigned to symbols according to the given valuation.
Only a subset of expression nodes can be processed by this evaluator.
"""
def visit(node):
match node:
case PsSymbolExpr(symb):
return valuation[symb.name]
case PsConstantExpr(c):
return c.value
case PsUnOp(op1) if node.python_operator is not None:
return node.python_operator(visit(op1))
case PsBinOp(op1, op2) if node.python_operator is not None:
return node.python_operator(visit(op1), visit(op2))
case other:
raise NotImplementedError(
f"Unable to evaluate {other}: No implementation available."
)
return visit(expr)
from typing import Callable, Generator
from .structural import PsAstNode
def dfs_preorder(
    node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True
) -> Generator[PsAstNode, None, None]:
    """Pre-Order depth-first traversal of an abstract syntax tree.

    A node is visited before its children. Nodes rejected by the filter are
    not yielded, but their subtrees are still traversed.

    Args:
        node: The tree's root node
        filter_pred: Filter predicate; a node is only returned to the caller if ``filter_pred(node)`` returns True
    """
    if filter_pred(node):
        yield node

    for child in node.children:
        yield from dfs_preorder(child, filter_pred)
def dfs_postorder(
    node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True
) -> Generator[PsAstNode, None, None]:
    """Traverse an abstract syntax tree depth-first, visiting each node after all of its children.

    Args:
        node: The tree's root node
        filter_pred: Filter predicate; a node is only yielded to the caller if
            ``filter_pred(node)`` returns True. Children of rejected nodes are still traversed.
    """
    #   Stack entries are pairs (node, expanded); a node is pushed a second time with
    #   expanded=True so it is revisited once its entire subtree has been processed
    pending = [(node, False)]
    while pending:
        current, expanded = pending.pop()
        if expanded:
            if filter_pred(current):
                yield current
        else:
            pending.append((current, True))
            #   Push children in reverse so that the first child ends up on top of the stack
            pending.extend(
                (child, False) for child in reversed(tuple(current.children))
            )
from __future__ import annotations
from typing import Sequence, cast
from types import NoneType
from .astnode import PsAstNode, PsLeafMixIn
from .expressions import PsExpression, PsLvalue, PsSymbolExpr
from ..memory import PsSymbol
from .util import failing_cast
class PsBlock(PsAstNode):
    """AST node holding an ordered list of child nodes (its statements).

    Args:
        cs: The nodes forming the block; they are copied into a new list
    """

    __match_args__ = ("statements",)

    def __init__(self, cs: Sequence[PsAstNode]):
        self._statements = [*cs]

    @property
    def statements(self) -> list[PsAstNode]:
        """The block's internal statement list (not a copy)."""
        return self._statements

    @statements.setter
    def statements(self, stm: Sequence[PsAstNode]):
        self._statements = [*stm]

    @property
    def children(self) -> Sequence[PsAstNode]:
        """The statements as returned by `get_children`; assigning replaces all statements."""
        return self.get_children()

    @children.setter
    def children(self, cs: Sequence[PsAstNode]):
        self._statements = [*cs]

    def get_children(self) -> tuple[PsAstNode, ...]:
        """Return a tuple snapshot of the current statements."""
        return (*self._statements,)

    def set_child(self, idx: int, c: PsAstNode):
        """Replace the statement at position ``idx`` by ``c``."""
        self._statements[idx] = c

    def clone(self) -> PsBlock:
        """Return a new block containing clones of all statements."""
        cloned_statements = [statement.clone() for statement in self._statements]
        return PsBlock(cloned_statements)

    def __repr__(self) -> str:
        contents = ", ".join(map(repr, self.children))
        return f"PsBlock( {contents} )"
class PsStatement(PsAstNode):
    """AST node wrapping a single expression as a statement.

    Args:
        expr: The wrapped expression
    """

    __match_args__ = ("expression",)

    def __init__(self, expr: PsExpression):
        self._expression = expr

    @property
    def expression(self) -> PsExpression:
        """The wrapped expression."""
        return self._expression

    @expression.setter
    def expression(self, expr: PsExpression):
        self._expression = expr

    def clone(self) -> PsStatement:
        """Return a new statement wrapping a clone of the expression."""
        return PsStatement(self._expression.clone())

    def get_children(self) -> tuple[PsAstNode, ...]:
        return (self._expression,)

    def set_child(self, idx: int, c: PsAstNode):
        """Replace the expression; ``idx`` must address the single child (``0`` or ``-1``)."""
        idx = [0][idx]  # normalizes the index; anything out of range raises IndexError
        assert idx == 0
        self._expression = failing_cast(PsExpression, c)

    def __repr__(self) -> str:
        #   Structural representation, like the one of the sibling node classes.
        #   `AstEqWrapper` computes its hash from `repr`, so the representation must not
        #   depend on object identity.
        return f"PsStatement({repr(self._expression)})"
class PsAssignment(PsAstNode):
__match_args__ = (
"lhs",
"rhs",
)
def __init__(self, lhs: PsExpression, rhs: PsExpression):
if not isinstance(lhs, PsLvalue):
raise ValueError("Assignment LHS must be an lvalue")
self._lhs: PsExpression = lhs
self._rhs = rhs
@property
def lhs(self) -> PsExpression:
return self._lhs
@lhs.setter
def lhs(self, lvalue: PsExpression):
if not isinstance(lvalue, PsLvalue):
raise ValueError("Assignment LHS must be an lvalue")
self._lhs = lvalue
@property
def rhs(self) -> PsExpression:
return self._rhs
@rhs.setter
def rhs(self, expr: PsExpression):
self._rhs = expr
def clone(self) -> PsAssignment:
return PsAssignment(self._lhs.clone(), self._rhs.clone())
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._lhs, self._rhs)
def set_child(self, idx: int, c: PsAstNode):
idx = [0, 1][idx] # trick to normalize index
if idx == 0:
self.lhs = failing_cast(PsExpression, c)
elif idx == 1:
self._rhs = failing_cast(PsExpression, c)
else:
assert False, "unreachable code"
def __repr__(self) -> str:
return f"PsAssignment({repr(self._lhs)}, {repr(self._rhs)})"
class PsDeclaration(PsAssignment):
    """Assignment whose left-hand side is a symbol expression naming the declared symbol.

    Args:
        lhs: Symbol expression of the declared symbol
        rhs: Expression whose value is assigned

    Raises:
        ValueError: If ``lhs`` is not an lvalue (raised by the `PsAssignment` constructor)
        TypeError: If ``lhs`` is an lvalue, but not a `PsSymbolExpr`
    """

    __match_args__ = (
        "lhs",
        "rhs",
    )

    def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression):
        super().__init__(lhs, rhs)
        #   The base constructor only checks that `lhs` is an lvalue.
        #   Apply the same symbol-expression check that the `lhs` setter and `set_child` perform,
        #   such that `declared_symbol` can never be called on a non-symbol left-hand side.
        failing_cast(PsSymbolExpr, lhs)

    @property
    def lhs(self) -> PsExpression:
        """The declared symbol's expression. Assigning anything but a `PsSymbolExpr` raises `TypeError`."""
        return self._lhs

    @lhs.setter
    def lhs(self, lvalue: PsExpression):
        self._lhs = failing_cast(PsSymbolExpr, lvalue)

    @property
    def declared_symbol(self) -> PsSymbol:
        """The symbol referred to by the left-hand side."""
        return cast(PsSymbolExpr, self._lhs).symbol

    def clone(self) -> PsDeclaration:
        """Return a new declaration built from clones of both sides."""
        return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone())

    def set_child(self, idx: int, c: PsAstNode):
        """Replace the left-hand side (index 0) or right-hand side (index 1)."""
        idx = [0, 1][idx]  # trick to normalize index
        if idx == 0:
            self.lhs = failing_cast(PsSymbolExpr, c)
        elif idx == 1:
            self._rhs = failing_cast(PsExpression, c)
        else:
            assert False, "unreachable code"

    def __repr__(self) -> str:
        return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})"
class PsLoop(PsAstNode):
__match_args__ = ("counter", "start", "stop", "step", "body")
def __init__(
self,
ctr: PsSymbolExpr,
start: PsExpression,
stop: PsExpression,
step: PsExpression,
body: PsBlock,
):
self._ctr = ctr
self._start = start
self._stop = stop
self._step = step
self._body = body
@property
def counter(self) -> PsSymbolExpr:
return self._ctr
@counter.setter
def counter(self, expr: PsSymbolExpr):
self._ctr = expr
@property
def start(self) -> PsExpression:
return self._start
@start.setter
def start(self, expr: PsExpression):
self._start = expr
@property
def stop(self) -> PsExpression:
return self._stop
@stop.setter
def stop(self, expr: PsExpression):
self._stop = expr
@property
def step(self) -> PsExpression:
return self._step
@step.setter
def step(self, expr: PsExpression):
self._step = expr
@property
def body(self) -> PsBlock:
return self._body
@body.setter
def body(self, block: PsBlock):
self._body = block
def clone(self) -> PsLoop:
return PsLoop(
self._ctr.clone(),
self._start.clone(),
self._stop.clone(),
self._step.clone(),
self._body.clone(),
)
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._ctr, self._start, self._stop, self._step, self._body)
def set_child(self, idx: int, c: PsAstNode):
idx = list(range(5))[idx]
match idx:
case 0:
self._ctr = failing_cast(PsSymbolExpr, c)
case 1:
self._start = failing_cast(PsExpression, c)
case 2:
self._stop = failing_cast(PsExpression, c)
case 3:
self._step = failing_cast(PsExpression, c)
case 4:
self._body = failing_cast(PsBlock, c)
case _:
assert False, "unreachable code"
class PsConditional(PsAstNode):
    """Conditional branch with a mandatory true-branch and an optional false-branch.

    Args:
        cond: The branch condition
        branch_true: Block executed if the condition holds
        branch_false: Block executed otherwise, or ``None`` if there is no such branch
    """

    __match_args__ = ("condition", "branch_true", "branch_false")

    def __init__(
        self,
        cond: PsExpression,
        branch_true: PsBlock,
        branch_false: PsBlock | None = None,
    ):
        self._condition = cond
        self._branch_true = branch_true
        self._branch_false = branch_false

    @property
    def condition(self) -> PsExpression:
        return self._condition

    @condition.setter
    def condition(self, expr: PsExpression):
        self._condition = expr

    @property
    def branch_true(self) -> PsBlock:
        return self._branch_true

    @branch_true.setter
    def branch_true(self, block: PsBlock):
        self._branch_true = block

    @property
    def branch_false(self) -> PsBlock | None:
        return self._branch_false

    @branch_false.setter
    def branch_false(self, block: PsBlock | None):
        self._branch_false = block

    def clone(self) -> PsConditional:
        """Return a new conditional built from clones of the condition and the present branches."""
        cond_clone = self._condition.clone()
        true_clone = self._branch_true.clone()
        false_clone = (
            None if self._branch_false is None else self._branch_false.clone()
        )
        return PsConditional(cond_clone, true_clone, false_clone)

    def get_children(self) -> tuple[PsAstNode, ...]:
        """Return condition and true-branch, followed by the false-branch only if it is present."""
        if self._branch_false is None:
            return (self._condition, self._branch_true)
        return (self._condition, self._branch_true, self._branch_false)

    def set_child(self, idx: int, c: PsAstNode):
        """Replace the condition (0), the true-branch (1), or the false-branch (2).

        The index is normalized against three slots regardless of whether a false-branch is present.
        """
        slot = (0, 1, 2)[idx]  # raises IndexError if idx is out of range
        if slot == 0:
            self._condition = failing_cast(PsExpression, c)
        elif slot == 1:
            self._branch_true = failing_cast(PsBlock, c)
        elif slot == 2:
            self._branch_false = failing_cast((PsBlock, NoneType), c)
        else:
            assert False, "unreachable code"

    def __repr__(self) -> str:
        return f"PsConditional({self._condition!r}, {self._branch_true!r}, {self._branch_false!r})"
class PsEmptyLeafMixIn:
    """Marker mix-in for AST leaves which the code generator may treat as empty,
    for instance comments and preprocessor directives."""
class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
    """A C/C++ preprocessor pragma.

    Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``.

    Args:
        text: The pragma's text, without the ``#pragma``.
    """

    __match_args__ = ("text",)

    def __init__(self, text: str) -> None:
        self._text = text

    @property
    def text(self) -> str:
        """The pragma's text, without the ``#pragma``."""
        return self._text

    def clone(self) -> PsPragma:
        return PsPragma(self.text)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Two pragmas are structurally equal if and only if their texts are equal."""
        if not isinstance(other, PsPragma):
            return False

        return self._text == other._text

    def __repr__(self) -> str:
        #   `AstEqWrapper` hashes nodes via `repr`; deriving the representation from the text
        #   keeps that hash consistent with `structurally_equal`
        return f"PsPragma({repr(self._text)})"
class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
    """A comment in the generated code.

    Args:
        text: The comment's text; it is split into individual lines using ``str.splitlines``
    """

    __match_args__ = ("lines",)

    def __init__(self, text: str) -> None:
        self._text = text
        self._lines = tuple(text.splitlines())

    @property
    def text(self) -> str:
        """The comment text exactly as passed to the constructor."""
        return self._text

    @property
    def lines(self) -> tuple[str, ...]:
        """The comment text split into lines; empty if the text is empty."""
        return self._lines

    def clone(self) -> PsComment:
        return PsComment(self._text)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Two comments are structurally equal if and only if their texts are equal."""
        if not isinstance(other, PsComment):
            return False

        return self._text == other._text

    def __repr__(self) -> str:
        #   `AstEqWrapper` hashes nodes via `repr`; deriving the representation from the text
        #   keeps that hash consistent with `structurally_equal`
        return f"PsComment({repr(self._text)})"
from __future__ import annotations
from typing import Any, TYPE_CHECKING, cast
from ..exceptions import PsInternalCompilerError
from ..memory import PsSymbol
from ..memory import PsBuffer
from ...types import PsDereferencableType
if TYPE_CHECKING:
from .astnode import PsAstNode
from .expressions import PsExpression
def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any:
    """Return ``obj`` unchanged after checking that it is an instance of ``target``.

    Args:
        target: A type, or tuple of types, as accepted by ``isinstance``
        obj: The object to check

    Raises:
        TypeError: If ``obj`` is not an instance of ``target``
    """
    if isinstance(obj, target):
        return obj
    raise TypeError(f"Casting {obj} to {target} failed.")
class AstEqWrapper:
    """Hashable wrapper comparing AST nodes by structure instead of identity.

    The hash is computed from the wrapped node's ``repr``, and ``__eq__`` is forwarded to
    `structurally_equal <PsAstNode.structurally_equal>`.
    This makes the wrapper usable as a dictionary key when subtrees shall be tracked
    according to their structure, e.g. in elimination of constants or common subexpressions.
    """

    def __init__(self, node: PsAstNode):
        self._node = node

    @property
    def n(self):
        """The wrapped node."""
        return self._node

    def __eq__(self, other: object) -> bool:
        if isinstance(other, AstEqWrapper):
            return self._node.structurally_equal(other._node)
        return False

    def __hash__(self) -> int:
        #   TODO: consider replacing this with smth. more performant
        #   TODO: Check that repr is implemented by all AST nodes
        textual = repr(self._node)
        return hash(textual)
def determine_memory_object(
expr: PsExpression,
) -> tuple[PsSymbol | PsBuffer | None, bool]:
"""Return the memory object accessed by the given expression, together with its constness
Returns:
Tuple ``(mem_obj, const)`` identifying the memory object accessed by the given expression,
as well as its constness
"""
from pystencils.backend.ast.expressions import (
PsSubscript,
PsLookup,
PsSymbolExpr,
PsMemAcc,
PsBufferAcc,
)
while isinstance(expr, (PsSubscript, PsLookup)):
match expr:
case PsSubscript(arr, _):
expr = arr
case PsLookup(record, _):
expr = record
match expr:
case PsSymbolExpr(symb):
return symb, symb.get_dtype().const
case PsMemAcc(ptr, _):
return None, cast(PsDereferencableType, ptr.get_dtype()).base_type.const
case PsBufferAcc(ptr, _):
return (
expr.buffer,
cast(PsDereferencableType, ptr.get_dtype()).base_type.const,
)
case _:
raise PsInternalCompilerError(
"The given expression is a transient and does not refer to a memory object"
)
from __future__ import annotations
from typing import cast
from .astnode import PsAstNode
from .expressions import PsExpression, PsLvalue, PsUnOp
from .util import failing_cast
from ...types import PsVectorType
class PsVectorOp:
    """Mix-in for vector operations.

    The class has no members of its own; `PsVecBroadcast` and `PsVecMemAcc` derive from it.
    """
class PsVecBroadcast(PsUnOp, PsVectorOp):
    """Broadcast a scalar value to N vector lanes.

    Args:
        lanes: Number of vector lanes
        operand: The expression being broadcast
    """

    __match_args__ = ("lanes", "operand")

    def __init__(self, lanes: int, operand: PsExpression):
        super().__init__(operand)
        self._lanes = lanes

    @property
    def lanes(self) -> int:
        """Number of vector lanes."""
        return self._lanes

    @lanes.setter
    def lanes(self, n: int):
        self._lanes = n

    def _clone_expr(self) -> PsVecBroadcast:
        operand_clone = self._operand.clone()
        return PsVecBroadcast(self._lanes, operand_clone)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Broadcasts are equal if the base class check passes and the lane counts match."""
        return (
            isinstance(other, PsVecBroadcast)
            and super().structurally_equal(other)
            and self._lanes == other._lanes
        )
class PsVecMemAcc(PsExpression, PsLvalue, PsVectorOp):
    """Vectorized memory access through a pointer.

    Args:
        base_ptr: Pointer identifying the accessed memory region
        offset: Offset inside the memory region
        vector_entries: Number of elements to access
        stride: Optional integer step size for strided access, or ``None`` for contiguous access
        aligned: For contiguous accesses, whether the access is guaranteed to be naturally aligned
            according to the vector data type
    """

    __match_args__ = ("pointer", "offset", "vector_entries", "stride", "aligned")

    def __init__(
        self,
        base_ptr: PsExpression,
        offset: PsExpression,
        vector_entries: int,
        stride: PsExpression | None = None,
        aligned: bool = False,
    ):
        super().__init__()

        self._ptr = base_ptr
        self._offset = offset
        self._vector_entries = vector_entries
        self._stride = stride
        self._aligned = aligned

    #   Read-only attributes

    @property
    def vector_entries(self) -> int:
        return self._vector_entries

    @property
    def aligned(self) -> bool:
        return self._aligned

    #   Child expressions

    @property
    def pointer(self) -> PsExpression:
        return self._ptr

    @pointer.setter
    def pointer(self, expr: PsExpression):
        self._ptr = expr

    @property
    def offset(self) -> PsExpression:
        return self._offset

    @offset.setter
    def offset(self, expr: PsExpression):
        self._offset = expr

    @property
    def stride(self) -> PsExpression | None:
        return self._stride

    @stride.setter
    def stride(self, expr: PsExpression | None):
        self._stride = expr

    def get_vector_type(self) -> PsVectorType:
        """Return this expression's data type, cast to `PsVectorType` for the type checker."""
        return cast(PsVectorType, self._dtype)

    def get_children(self) -> tuple[PsAstNode, ...]:
        """Return pointer and offset, followed by the stride only if one is set."""
        if self._stride is None:
            return (self._ptr, self._offset)
        return (self._ptr, self._offset, self._stride)

    def set_child(self, idx: int, c: PsAstNode):
        """Replace the pointer (0), the offset (1), or the stride (2)."""
        slot = (0, 1, 2)[idx]  # raises IndexError if idx is out of range
        child = failing_cast(PsExpression, c)
        if slot == 0:
            self._ptr = child
        elif slot == 1:
            self._offset = child
        elif slot == 2:
            self._stride = child

    def _clone_expr(self) -> PsVecMemAcc:
        ptr_clone = self._ptr.clone()
        offset_clone = self._offset.clone()
        stride_clone = None if self._stride is None else self._stride.clone()
        return PsVecMemAcc(
            ptr_clone,
            offset_clone,
            self.vector_entries,
            stride_clone,
            self._aligned,
        )

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Accesses are equal if the base class check passes and entry count and alignment match."""
        return (
            isinstance(other, PsVecMemAcc)
            and super().structurally_equal(other)
            and self._vector_entries == other._vector_entries
            and self._aligned == other._aligned
        )

    def __repr__(self) -> str:
        return (
            f"PsVecMemAcc({self._ptr!r}, {self._offset!r}, {self._vector_entries!r}, "
            f"stride={self._stride!r}, aligned={self._aligned!r})"
        )
from __future__ import annotations
from typing import Any
import numpy as np
from ..types import PsNumericType, constify
from .exceptions import PsInternalCompilerError
class PsConstant:
    """Type-safe representation of typed numerical constants.

    This class models constants in the backend representation of kernels.
    A constant may be *untyped*, in which case its ``value`` may be any Python object.
    If the constant is *typed* (i.e. its ``dtype`` is not ``None``), its data type is used
    to check the validity of its ``value`` and to convert it into the type's internal representation.

    Instances of `PsConstant` are immutable.

    Args:
        value: The constant's value
        dtype: The constant's data type, or ``None`` if untyped.
    """

    __match_args__ = ("value", "dtype")

    def __init__(self, value: Any, dtype: PsNumericType | None = None):
        if dtype is None:
            self._dtype: PsNumericType | None = None
            self._value = value
        else:
            #   Typed constants store the const-qualified type
            #   and the value as converted by that type
            self._dtype = constify(dtype)
            self._value = self._dtype.create_constant(value)

    def interpret_as(self, dtype: PsNumericType) -> PsConstant:
        """Interprets this *untyped* constant with the given data type.

        If this constant is already typed, raises an error.

        Raises:
            PsInternalCompilerError: If this constant already has a data type
        """
        if self._dtype is not None:
            raise PsInternalCompilerError(
                f"Cannot interpret already typed constant {self} with type {dtype}"
            )

        return PsConstant(self._value, dtype)

    def reinterpret_as(self, dtype: PsNumericType) -> PsConstant:
        """Reinterprets this constant with the given data type.

        Other than `interpret_as`, this method also works on typed constants.
        """
        return PsConstant(self._value, dtype)

    @property
    def value(self) -> Any:
        """The constant's value."""
        return self._value

    @property
    def dtype(self) -> PsNumericType | None:
        """This constant's data type, or ``None`` if it is untyped.

        The data type of a constant always has ``const == True``.
        """
        return self._dtype

    def get_dtype(self) -> PsNumericType:
        """Retrieve this constant's data type, throwing an exception if the constant is untyped."""
        if self._dtype is None:
            raise PsInternalCompilerError("Data type of constant was not set.")
        return self._dtype

    def __str__(self) -> str:
        type_str = "<untyped>" if self._dtype is None else str(self._dtype)
        return f"{str(self._value)}: {type_str}"

    def __repr__(self) -> str:
        return str(self)

    def __hash__(self) -> int:
        value = self._value
        if isinstance(value, np.ndarray):
            #   NumPy arrays are unhashable, but `__eq__` supports them through `np.all`.
            #   Hash the entries as Python scalars instead, so that constants with
            #   array values can be used in sets and as dictionary keys.
            value = tuple(value.ravel().tolist())
        return hash((self._dtype, value))

    def __eq__(self, other) -> bool:
        if not isinstance(other, PsConstant):
            return False

        #   `np.all` reduces the element-wise comparison of array values to a single truth value
        return (self._dtype == other._dtype) and bool(
            np.all(self._value == other._value)
        )
from .base_printer import EmissionError
from .c_printer import emit_code, CAstPrinter
from .ir_printer import emit_ir, IRAstPrinter
__all__ = ["emit_code", "CAstPrinter", "emit_ir", "IRAstPrinter", "EmissionError"]
from __future__ import annotations
from enum import Enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from ..ast.structural import (
PsAstNode,
PsBlock,
PsStatement,
PsDeclaration,
PsAssignment,
PsLoop,
PsConditional,
PsComment,
PsPragma,
)
from ..ast.expressions import (
PsExpression,
PsAdd,
PsAddressOf,
PsArrayInitList,
PsBinOp,
PsBitwiseAnd,
PsBitwiseOr,
PsBitwiseXor,
PsCall,
PsCast,
PsConstantExpr,
PsMemAcc,
PsDiv,
PsRem,
PsIntDiv,
PsLeftShift,
PsLookup,
PsMul,
PsNeg,
PsRightShift,
PsSub,
PsSymbolExpr,
PsLiteralExpr,
PsTernary,
PsAnd,
PsOr,
PsNot,
PsEq,
PsNe,
PsGt,
PsLt,
PsGe,
PsLe,
PsSubscript,
)
from ..extensions.foreign_ast import PsForeignExpression
from ..memory import PsSymbol
from ..constants import PsConstant
from ...types import PsType
from ...codegen import Target
if TYPE_CHECKING:
from ...codegen import Kernel
class EmissionError(Exception):
    """Indicates a fatal error during code printing.

    The printers raise it when a node, constant or type cannot be expressed in the output language.
    """
class LR(Enum):
    """Position marker used during printing.

    It serves both as the associativity of an operator (see `Ops`) and as the operand branch
    of the enclosing operator that is currently being printed (see `PrinterCtx`).
    """

    Left = 0
    Right = 1
    Middle = 2
class Ops(Enum):
    """Operator precedence and associativity in C/C++.

    Each member's value is a pair ``(precedence, associativity)``, which is unpacked into the
    ``precedence`` and ``assoc`` attributes. A larger precedence number means weaker binding:
    `PrinterCtx.parenthesize` encloses an expression in parentheses if its operator has a larger
    number than the enclosing operator.

    See also https://en.cppreference.com/w/cpp/language/operator_precedence
    """

    #   NOTE: Members defined with identical value pairs (e.g. `Call`, `Subscript` and `Lookup`)
    #   are aliases of the first such member according to `Enum` semantics.
    #   The order of definition therefore determines which name is the canonical one.

    Call = (2, LR.Left)
    Subscript = (2, LR.Left)
    Lookup = (2, LR.Left)

    Neg = (3, LR.Right)
    Not = (3, LR.Right)
    AddressOf = (3, LR.Right)
    Deref = (3, LR.Right)
    Cast = (3, LR.Right)

    Mul = (5, LR.Left)
    Div = (5, LR.Left)
    Rem = (5, LR.Left)

    Add = (6, LR.Left)
    Sub = (6, LR.Left)

    LeftShift = (7, LR.Left)
    RightShift = (7, LR.Left)

    RelOp = (9, LR.Left)  # >=, >, <, <=

    EqOp = (10, LR.Left)  # == and !=

    BitwiseAnd = (11, LR.Left)

    BitwiseXor = (12, LR.Left)

    BitwiseOr = (13, LR.Left)

    LogicAnd = (14, LR.Left)

    LogicOr = (15, LR.Left)

    Ternary = (16, LR.Right)

    #   Pseudo-operator binding weaker than all others; used as the context in which
    #   no parentheses are required (top level, index lists, argument lists)
    Weakest = (17, LR.Middle)

    def __init__(self, pred: int, assoc: LR) -> None:
        self.precedence = pred
        self.assoc = assoc
class PrinterCtx:
    """Mutable state carried through a single printing run.

    The context tracks the stack of enclosing operators together with the operand branch
    currently being printed for each of them, as well as the current indentation level.
    """

    def __init__(self) -> None:
        self.operator_stack = [Ops.Weakest]
        self.branch_stack = [LR.Middle]
        self.indent_level = 0

    @property
    def current_op(self) -> Ops:
        """The innermost enclosing operator."""
        return self.operator_stack[-1]

    @property
    def current_branch(self) -> LR:
        """The operand branch of the innermost enclosing operator that is being printed."""
        return self.branch_stack[-1]

    def push_op(self, operator: Ops, branch: LR):
        """Enter the given operand branch of a new enclosing operator."""
        self.operator_stack.append(operator)
        self.branch_stack.append(branch)

    def pop_op(self) -> None:
        """Leave the innermost enclosing operator."""
        self.operator_stack.pop()
        self.branch_stack.pop()

    def switch_branch(self, branch: LR):
        """Move on to another operand branch of the innermost enclosing operator."""
        self.branch_stack[-1] = branch

    def parenthesize(self, expr: str, next_operator: Ops) -> str:
        """Enclose ``expr`` in parentheses if its operator requires it in the current context.

        Parentheses are added if ``next_operator`` has a larger precedence number than the
        enclosing operator, or if both numbers are equal and the current branch differs from
        the enclosing operator's associativity.
        """
        enclosing = self.current_op
        needs_parentheses = next_operator.precedence > enclosing.precedence or (
            next_operator.precedence == enclosing.precedence
            and self.current_branch != enclosing.assoc
        )
        return f"({expr})" if needs_parentheses else expr

    def indent(self, line: str) -> str:
        """Prefix ``line`` with the current indentation."""
        prefix = " " * self.indent_level
        return prefix + line
class BasePrinter(ABC):
"""Base code printer.
The base printer is capable of printing syntax tree nodes valid in all output dialects.
It is specialized in `CAstPrinter` for the C output language,
and in `IRAstPrinter` for debug-printing the entire IR.
"""
def __init__(self, indent_width=3, func_prefix: str | None = None):
self._indent_width = indent_width
self._func_prefix = func_prefix
def __call__(self, obj: PsAstNode | Kernel) -> str:
from ...codegen import Kernel
if isinstance(obj, Kernel):
sig = self.print_signature(obj)
body_code = self.visit(obj.body, PrinterCtx())
return f"{sig}\n{body_code}"
else:
return self.visit(obj, PrinterCtx())
def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
match node:
case PsBlock(statements):
if not statements:
return pc.indent("{ }")
pc.indent_level += self._indent_width
interior = "\n".join(self.visit(stmt, pc) for stmt in statements) + "\n"
pc.indent_level -= self._indent_width
return pc.indent("{\n") + interior + pc.indent("}")
case PsStatement(expr):
return pc.indent(f"{self.visit(expr, pc)};")
case PsDeclaration(lhs, rhs):
lhs_symb = node.declared_symbol
lhs_code = self._symbol_decl(lhs_symb)
rhs_code = self.visit(rhs, pc)
return pc.indent(f"{lhs_code} = {rhs_code};")
case PsAssignment(lhs, rhs):
lhs_code = self.visit(lhs, pc)
rhs_code = self.visit(rhs, pc)
return pc.indent(f"{lhs_code} = {rhs_code};")
case PsLoop(ctr, start, stop, step, body):
ctr_symbol = ctr.symbol
ctr_decl = self._symbol_decl(ctr_symbol)
start_code = self.visit(start, pc)
stop_code = self.visit(stop, pc)
step_code = self.visit(step, pc)
body_code = self.visit(body, pc)
code = (
f"for({ctr_decl} = {start_code};"
+ f" {ctr.symbol.name} < {stop_code};"
+ f" {ctr.symbol.name} += {step_code})\n"
+ body_code
)
return pc.indent(code)
case PsConditional(condition, branch_true, branch_false):
cond_code = self.visit(condition, pc)
then_code = self.visit(branch_true, pc)
code = f"if({cond_code})\n{then_code}"
if branch_false is not None:
else_code = self.visit(branch_false, pc)
code += f"\nelse\n{else_code}"
return pc.indent(code)
case PsComment(lines):
lines_list = list(lines)
lines_list[0] = "/* " + lines_list[0]
for i in range(1, len(lines_list)):
lines_list[i] = " " + lines_list[i]
lines_list[-1] = lines_list[-1] + " */"
return pc.indent("\n".join(lines_list))
case PsPragma(text):
return pc.indent("#pragma " + text)
case PsSymbolExpr(symbol):
return symbol.name
case PsConstantExpr(constant):
return self._constant_literal(constant)
case PsLiteralExpr(lit):
return lit.text
case PsMemAcc(base, offset):
pc.push_op(Ops.Subscript, LR.Left)
base_code = self.visit(base, pc)
pc.pop_op()
pc.push_op(Ops.Weakest, LR.Middle)
index_code = self.visit(offset, pc)
pc.pop_op()
return pc.parenthesize(f"{base_code}[{index_code}]", Ops.Subscript)
case PsSubscript(base, indices):
pc.push_op(Ops.Subscript, LR.Left)
base_code = self.visit(base, pc)
pc.pop_op()
pc.push_op(Ops.Weakest, LR.Middle)
indices_code = "".join(
"[" + self.visit(idx, pc) + "]" for idx in indices
)
pc.pop_op()
return pc.parenthesize(base_code + indices_code, Ops.Subscript)
case PsLookup(aggr, member_name):
pc.push_op(Ops.Lookup, LR.Left)
aggr_code = self.visit(aggr, pc)
pc.pop_op()
return pc.parenthesize(f"{aggr_code}.{member_name}", Ops.Lookup)
case PsCall(function, args):
pc.push_op(Ops.Weakest, LR.Middle)
args_string = ", ".join(self.visit(arg, pc) for arg in args)
pc.pop_op()
return pc.parenthesize(f"{function.name}({args_string})", Ops.Call)
case PsBinOp(op1, op2):
op_char, op = self._char_and_op(node)
pc.push_op(op, LR.Left)
op1_code = self.visit(op1, pc)
pc.switch_branch(LR.Right)
op2_code = self.visit(op2, pc)
pc.pop_op()
return pc.parenthesize(f"{op1_code} {op_char} {op2_code}", op)
case PsNeg(operand):
pc.push_op(Ops.Neg, LR.Right)
operand_code = self.visit(operand, pc)
pc.pop_op()
return pc.parenthesize(f"-{operand_code}", Ops.Neg)
case PsNot(operand):
pc.push_op(Ops.Not, LR.Right)
operand_code = self.visit(operand, pc)
pc.pop_op()
return pc.parenthesize(f"!{operand_code}", Ops.Not)
case PsAddressOf(operand):
pc.push_op(Ops.AddressOf, LR.Right)
operand_code = self.visit(operand, pc)
pc.pop_op()
return pc.parenthesize(f"&{operand_code}", Ops.AddressOf)
case PsCast(target_type, operand):
pc.push_op(Ops.Cast, LR.Right)
operand_code = self.visit(operand, pc)
pc.pop_op()
type_str = self._type_str(target_type)
return pc.parenthesize(f"({type_str}) {operand_code}", Ops.Cast)
case PsTernary(cond, then, els):
pc.push_op(Ops.Ternary, LR.Left)
cond_code = self.visit(cond, pc)
pc.switch_branch(LR.Middle)
then_code = self.visit(then, pc)
pc.switch_branch(LR.Right)
else_code = self.visit(els, pc)
pc.pop_op()
return pc.parenthesize(
f"{cond_code} ? {then_code} : {else_code}", Ops.Ternary
)
case PsArrayInitList(_):
def print_arr(item) -> str:
if isinstance(item, PsExpression):
return self.visit(item, pc)
else:
# it's a subarray
entries = ", ".join(print_arr(i) for i in item)
return "{ " + entries + " }"
pc.push_op(Ops.Weakest, LR.Middle)
arr_str = print_arr(node.items_grid)
pc.pop_op()
return arr_str
case PsForeignExpression(children):
pc.push_op(Ops.Weakest, LR.Middle)
foreign_code = node.get_code(self.visit(c, pc) for c in children)
pc.pop_op()
return foreign_code
case _:
raise NotImplementedError(
f"BasePrinter does not know how to print {type(node)}"
)
def print_signature(self, func: Kernel) -> str:
params_str = ", ".join(
f"{self._type_str(p.dtype)} {p.name}" for p in func.parameters
)
from ...codegen import GpuKernel
sig_parts = [self._func_prefix] if self._func_prefix is not None else []
if isinstance(func, GpuKernel) and func.target == Target.CUDA:
sig_parts.append("__global__")
sig_parts += ["void", func.name, f"({params_str})"]
signature = " ".join(sig_parts)
return signature
@abstractmethod
def _symbol_decl(self, symb: PsSymbol) -> str:
pass
@abstractmethod
def _constant_literal(self, constant: PsConstant) -> str:
pass
@abstractmethod
def _type_str(self, dtype: PsType) -> str:
"""Return a valid string representation of the given type"""
def _char_and_op(self, node: PsBinOp) -> tuple[str, Ops]:
match node:
case PsAdd():
return ("+", Ops.Add)
case PsSub():
return ("-", Ops.Sub)
case PsMul():
return ("*", Ops.Mul)
case PsDiv() | PsIntDiv():
return ("/", Ops.Div)
case PsRem():
return ("%", Ops.Rem)
case PsLeftShift():
return ("<<", Ops.LeftShift)
case PsRightShift():
return (">>", Ops.RightShift)
case PsBitwiseAnd():
return ("&", Ops.BitwiseAnd)
case PsBitwiseXor():
return ("^", Ops.BitwiseXor)
case PsBitwiseOr():
return ("|", Ops.BitwiseOr)
case PsAnd():
return ("&&", Ops.LogicAnd)
case PsOr():
return ("||", Ops.LogicOr)
case PsEq():
return ("==", Ops.EqOp)
case PsNe():
return ("!=", Ops.EqOp)
case PsGt():
return (">", Ops.RelOp)
case PsGe():
return (">=", Ops.RelOp)
case PsLt():
return ("<", Ops.RelOp)
case PsLe():
return ("<=", Ops.RelOp)
case _:
assert False
from __future__ import annotations
from typing import TYPE_CHECKING
from pystencils.backend.ast.astnode import PsAstNode
from pystencils.backend.constants import PsConstant
from pystencils.backend.emission.base_printer import PrinterCtx, EmissionError
from pystencils.backend.memory import PsSymbol
from .base_printer import BasePrinter
from ...types import PsType, PsArrayType, PsScalarType, PsTypeError
from ..ast.expressions import PsBufferAcc
from ..ast.vector import PsVecMemAcc
if TYPE_CHECKING:
from ...codegen import Kernel
def emit_code(ast: PsAstNode | Kernel):
    """Print the given AST or kernel using a `CAstPrinter` with default settings."""
    return CAstPrinter()(ast)
class CAstPrinter(BasePrinter):
    """Printer for the C output language.

    Vector memory accesses and buffer accesses are rejected, as they have to be
    replaced by other nodes before C code can be emitted.
    """

    def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
        if isinstance(node, PsVecMemAcc):
            raise EmissionError(
                f"Unable to print C code for vector memory access {node}.\n"
                f"Vectorized memory accesses must be mapped to intrinsics before emission."
            )

        if isinstance(node, PsBufferAcc):
            raise EmissionError(
                f"Unable to print C code for buffer access {node}.\n"
                f"Buffer accesses must be lowered using the `LowerToC` pass before emission."
            )

        return super().visit(node, pc)

    def _symbol_decl(self, symb: PsSymbol):
        """Print a declarator: the type, the symbol's name, and one bracket pair per array dimension."""
        dtype = symb.get_dtype()

        array_dims = ()
        if isinstance(dtype, PsArrayType):
            #   Arrays are declared with their element type; the dimensions follow the name
            array_dims, dtype = dtype.shape, dtype.base_type

        declarator = f"{self._type_str(dtype)} {symb.name}"
        #   Dimensions given as `None` are printed as empty brackets
        brackets = "".join(f"[{'' if d is None else str(d)}]" for d in array_dims)
        return declarator + brackets

    def _constant_literal(self, constant: PsConstant):
        """Print the literal of a constant of scalar type; other constants raise `EmissionError`."""
        dtype = constant.get_dtype()
        if isinstance(dtype, PsScalarType):
            return dtype.create_literal(constant.value)

        raise EmissionError("Cannot print literals for non-scalar constants.")

    def _type_str(self, dtype: PsType):
        """Print the type's C string; a `PsTypeError` while doing so is turned into an `EmissionError`."""
        try:
            c_type = dtype.c_string()
        except PsTypeError:
            raise EmissionError(f"Unable to print type {dtype} as a C data type.")
        return c_type
from __future__ import annotations
from typing import TYPE_CHECKING
from pystencils.backend.constants import PsConstant
from pystencils.backend.emission.base_printer import PrinterCtx
from pystencils.backend.memory import PsSymbol
from pystencils.types.meta import PsType, deconstify
from .base_printer import BasePrinter, Ops, LR
from ..ast import PsAstNode
from ..ast.expressions import PsBufferAcc
from ..ast.vector import PsVecMemAcc, PsVecBroadcast
if TYPE_CHECKING:
from ...codegen import Kernel
def emit_ir(ir: PsAstNode | Kernel):
    """Print the IR as C-like pseudo-code for inspection, using an `IRAstPrinter` with default settings."""
    return IRAstPrinter()(ir)
class IRAstPrinter(BasePrinter):
"""Print the IR AST as pseudo-code.
This printer produces a complete pseudocode representation of a pystencils AST.
Other than the `CAstPrinter`, the `IRAstPrinter` is capable of emitting code for
each node defined in `ast <pystencils.backend.ast>`.
It is furthermore configurable w.r.t. the level of detail it should emit.
Args:
indent_width: Number of spaces with which to indent lines in each nested block.
annotate_constants: If ``True`` (the default), annotate all constant literals with their data type.
"""
def __init__(self, indent_width=3, annotate_constants: bool = True):
super().__init__(indent_width)
self._annotate_constants = annotate_constants
def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
match node:
case PsBufferAcc(ptr, indices):
pc.push_op(Ops.Subscript, LR.Left)
base_code = self.visit(ptr, pc)
pc.pop_op()
pc.push_op(Ops.Weakest, LR.Middle)
indices_code = ", ".join(self.visit(idx, pc) for idx in indices)
pc.pop_op()
return pc.parenthesize(
base_code + "[" + indices_code + "]", Ops.Subscript
)
case PsVecMemAcc(ptr, offset, lanes, stride):
pc.push_op(Ops.Subscript, LR.Left)
ptr_code = self.visit(ptr, pc)
pc.pop_op()
pc.push_op(Ops.Weakest, LR.Middle)
offset_code = self.visit(offset, pc)
pc.pop_op()
stride_code = "" if stride is None else f", stride={stride}"
code = f"vec_memacc< {lanes}{stride_code} >({ptr_code}, {offset_code})"
return pc.parenthesize(code, Ops.Subscript)
case PsVecBroadcast(lanes, operand):
pc.push_op(Ops.Weakest, LR.Middle)
operand_code = self.visit(operand, pc)
pc.pop_op()
return pc.parenthesize(
f"vec_broadcast<{lanes}>({operand_code})", Ops.Weakest
)
case _:
return super().visit(node, pc)
def _symbol_decl(self, symb: PsSymbol):
return f"{symb.name}: {self._type_str(symb.dtype)}"
def _constant_literal(self, constant: PsConstant) -> str:
if self._annotate_constants:
return f"[{constant.value}: {self._deconst_type_str(constant.dtype)}]"
else:
return str(constant.value)
def _type_str(self, dtype: PsType | None):
if dtype is None:
return "<untyped>"
else:
return str(dtype)
def _deconst_type_str(self, dtype: PsType | None):
if dtype is None:
return "<untyped>"
else:
return str(deconstify(dtype))
"""Errors and Exceptions raised by the backend during kernel translation."""
class PsInternalCompilerError(Exception):
    """Indicates an internal error during kernel translation, most likely due to a bug inside pystencils.

    It is raised, for instance, when the data type of an untyped `PsConstant` is requested.
    """
class PsInputError(Exception):
    """Indicates unsupported user input to the translation system."""
class KernelConstraintsError(Exception):
    """Indicates a constraint violation in the symbolic kernel."""
class FreezeError(Exception):
    """Indicates an error during expression freezing."""
class TypificationError(Exception):
    """Indicates a fatal error during typification of the AST."""
class VectorizationError(Exception):
    """Indicates an error during a vectorization procedure."""
class MaterializationError(Exception):
    """Signals a fatal error while materializing an abstract kernel component."""
"""
The module `pystencils.backend.extensions` contains extensions to the pystencils code generator
beyond its core functionality.
The tools and classes of this module are considered experimental;
their support by the remaining code generator is limited.
They can be used to model and generate code outside of the usual scope of pystencils,
such as non-standard syntax and types.
At the moment, the primary use case is the modelling of C++ syntax.
Foreign Syntax Support
======================
.. automodule:: pystencils.backend.extensions.foreign_ast
:members:
C++ Language Support
====================
.. automodule:: pystencils.backend.extensions.cpp
:members:
"""
from .foreign_ast import PsForeignExpression
# Names exported by a star-import of this package
__all__ = ["PsForeignExpression"]
from __future__ import annotations
from typing import Iterable, cast
from pystencils.backend.ast.astnode import PsAstNode
from ..ast.expressions import PsExpression
from .foreign_ast import PsForeignExpression
from ...types import PsType
class CppMethodCall(PsForeignExpression):
    """C++ method call on an expression.

    The object the method is invoked on is stored as the first child of this node,
    followed by the call arguments as the remaining children.

    Args:
        obj: Expression on which the method is called
        method: Name of the method
        return_type: Type handed to the base class as the data type of this expression
        args: Argument expressions of the call
    """

    def __init__(
        self, obj: PsExpression, method: str, return_type: PsType, args: Iterable = ()
    ):
        self._method = method
        self._return_type = return_type
        children = [obj] + list(args)
        super().__init__(children, return_type)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Check structural equality; in addition to the base-class check, method names must match."""
        if not isinstance(other, CppMethodCall):
            return False

        return super().structurally_equal(other) and self._method == other._method

    def _clone_expr(self) -> CppMethodCall:
        """Create a copy of this node whose children are clones of this node's children."""
        # Clone every child, so that the copy does not share any subtree with this node;
        # passing the children through unchanged would alias them between both trees.
        cloned_children = [cast(PsExpression, c).clone() for c in self.children]
        return CppMethodCall(
            cloned_children[0],
            self._method,
            self._return_type,
            cloned_children[1:],
        )

    def get_code(self, children_code: Iterable[str]) -> str:
        """Print the call as ``(<obj>).<method>(<arg>, ...)`` from the code of the children."""
        cs = list(children_code)
        # First entry belongs to the object, all others to the arguments
        obj_code = cs[0]
        args = ", ".join(cs[1:])
        return f"({obj_code}).{self._method}({args})"
from __future__ import annotations
from typing import Iterable
from abc import ABC, abstractmethod
from pystencils.backend.ast.astnode import PsAstNode
from ..ast.expressions import PsExpression
from ..ast.util import failing_cast
from ...types import PsType
class PsForeignExpression(PsExpression, ABC):
    """Base class for foreign expressions.

    Foreign expressions are expressions whose properties are not modelled by the pystencils AST,
    and which pystencils therefore does not understand.

    There are many situations where non-supported expressions are needed;
    the most common use case is C++ syntax.
    Support for foreign expressions by the code generator is therefore very limited;
    as a rule of thumb, only printing is supported.
    Type checking and most transformations will fail when encountering a `PsForeignExpression`.

    Args:
        children: The child expressions of this node
        dtype: Optional data type, forwarded to the `PsExpression` constructor
    """

    __match_args__ = ("children",)

    def __init__(self, children: Iterable[PsExpression], dtype: PsType | None = None):
        # Stored as a list, since `set_child` replaces entries in place
        self._children = list(children)
        super().__init__(dtype)

    @abstractmethod
    def get_code(self, children_code: Iterable[str]) -> str:
        """Print this expression, with the given code for each of its children."""

    def get_children(self) -> tuple[PsAstNode, ...]:
        """Return the children of this node as a tuple."""
        return tuple(self._children)

    def set_child(self, idx: int, c: PsAstNode):
        """Replace the child at position `idx`; `c` is passed through ``failing_cast(PsExpression, c)``."""
        self._children[idx] = failing_cast(PsExpression, c)

    def __repr__(self) -> str:
        # Use the bare class name; formatting `type(self)` itself would yield `<class '...'>`
        return f"{type(self).__name__}({self._children})"
"""
Functions supported by pystencils.
Every supported function might require handling logic in the following modules:
- In `freeze.FreezeExpressions`, a case in `map_Function` or a separate mapper method to catch its frontend variant
- In each backend platform, a case in `Platform.select_function` to map the function onto a concrete
C/C++ implementation
- If very special typing rules apply, a case in `typification.Typifier`.
In most cases, typification of function applications will require no special handling.
.. autoclass:: PsFunction
:members:
.. autoclass:: MathFunctions
:members:
:undoc-members:
.. autoclass:: PsMathFunction
:members:
.. autoclass:: CFunction
:members:
"""
from __future__ import annotations
from typing import Any, Sequence, TYPE_CHECKING
from abc import ABC
from enum import Enum
from ..types import PsType
from .exceptions import PsInternalCompilerError
if TYPE_CHECKING:
from .ast.expressions import PsExpression
class PsFunction(ABC):
    """Base class for functions occurring in the IR.

    Args:
        name: Name of the function
        num_args: Number of arguments the function takes
    """

    __match_args__ = ("name", "arg_count")

    def __init__(self, name: str, num_args: int):
        self._name, self._num_args = name, num_args

    @property
    def name(self) -> str:
        """The function's name."""
        return self._name

    @property
    def arg_count(self) -> int:
        """How many arguments the function takes."""
        return self._num_args

    def __call__(self, *args: PsExpression) -> Any:
        """Build a `PsCall` node applying this function to the given arguments."""
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import -- confirm)
        from .ast.expressions import PsCall

        call_node = PsCall(self, args)
        return call_node
class MathFunctions(Enum):
    """Mathematical functions supported by the backend.

    Each platform has to materialize these functions to a concrete implementation.
    Every member's value is a pair ``(function name, number of arguments)``,
    which is unpacked into the attributes ``function_name`` and ``num_args``.
    """

    # Unary functions
    Exp = ("exp", 1)
    Log = ("log", 1)
    Sin = ("sin", 1)
    Cos = ("cos", 1)
    Tan = ("tan", 1)
    Sinh = ("sinh", 1)
    Cosh = ("cosh", 1)
    ASin = ("asin", 1)
    ACos = ("acos", 1)
    ATan = ("atan", 1)
    Sqrt = ("sqrt", 1)
    Abs = ("abs", 1)
    Floor = ("floor", 1)
    Ceil = ("ceil", 1)

    # Binary functions
    Min = ("min", 2)
    Max = ("max", 2)
    Pow = ("pow", 2)
    ATan2 = ("atan2", 2)

    def __init__(self, name_str, arity):
        # Enum passes the entries of the member's value tuple as positional arguments
        self.function_name = name_str
        self.num_args = arity
class PsMathFunction(PsFunction):
    """Homogeneously typed mathematical functions.

    Wraps a member of `MathFunctions`; name and argument count are taken from that member.
    Two instances compare equal (and hash alike) if they wrap the same member.
    """

    __match_args__ = ("func",)

    def __init__(self, func: MathFunctions) -> None:
        super().__init__(func.function_name, func.num_args)
        self._func = func

    @property
    def func(self) -> MathFunctions:
        """The wrapped `MathFunctions` member."""
        return self._func

    def __str__(self) -> str:
        return str(self._func.function_name)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, PsMathFunction) and self._func == other._func

    def __hash__(self) -> int:
        return hash(self._func)
class CFunction(PsFunction):
    """A concrete C function.

    Instances of this class represent a C function by its name, parameter types, and return type.

    Args:
        name: Function name
        param_types: Types of the function parameters
        return_type: The function's return type
    """

    __match_args__ = ("name", "parameter_types", "return_type")

    @staticmethod
    def parse(obj) -> CFunction:
        """Parse the signature of a Python callable object to obtain a CFunction object.

        The callable must be fully annotated with type-like objects convertible by `create_type`.

        Raises:
            PsInternalCompilerError: If `obj` is not a Python function
        """
        import inspect
        from pystencils.types import create_type

        if not inspect.isfunction(obj):
            raise PsInternalCompilerError(f"Cannot parse object {obj} as a function")

        signature = inspect.signature(obj)
        name = obj.__name__
        # One type per parameter, in declaration order
        param_types = [
            create_type(p.annotation) for p in signature.parameters.values()
        ]
        return_type = create_type(signature.return_annotation)
        return CFunction(name, param_types, return_type)

    def __init__(self, name: str, param_types: Sequence[PsType], return_type: PsType):
        # The argument count is derived from the number of parameter types
        super().__init__(name, len(param_types))
        self._return_type = return_type
        self._param_types = tuple(param_types)

    @property
    def parameter_types(self) -> tuple[PsType, ...]:
        """Types of the function's parameters."""
        return self._param_types

    @property
    def return_type(self) -> PsType:
        """The function's return type."""
        return self._return_type

    def __str__(self) -> str:
        # Printed as `<return type> <name>(<type>, <type>, ...)`
        params = ", ".join(map(str, self._param_types))
        return f"{self._return_type} {self._name}({params})"

    def __repr__(self) -> str:
        return f"CFunction({self._name}, {self._param_types}, {self._return_type})"
from .context import KernelCreationContext
from .analysis import KernelAnalysis
from .freeze import FreezeExpressions
from .typification import Typifier
from .ast_factory import AstFactory
from .iteration_space import (
IterationSpace,
FullIterationSpace,
SparseIterationSpace,
create_full_iteration_space,
create_sparse_iteration_space,
)
# Names exported by a star-import of this package;
# each entry corresponds to one of the names imported above.
__all__ = [
    "KernelCreationContext",
    "KernelAnalysis",
    "FreezeExpressions",
    "Typifier",
    "AstFactory",
    "IterationSpace",
    "FullIterationSpace",
    "SparseIterationSpace",
    "create_full_iteration_space",
    "create_sparse_iteration_space",
]
from __future__ import annotations
from collections import namedtuple, defaultdict
from typing import Any, Sequence
from itertools import chain
import sympy as sp
from .context import KernelCreationContext
from ...field import Field
from ...simp import AssignmentCollection
from sympy.codegen.ast import AssignmentBase
from ..exceptions import PsInternalCompilerError, KernelConstraintsError
class KernelAnalysis:
"""General analysis pass over a kernel expressed using the SymPy frontend.
The kernel analysis fulfills two tasks. It checks the SymPy input for consistency,
and populates the context with required knowledge about the kernel.
A `KernelAnalysis` object may be called at most once.
**Consistency and Constraints**
The following checks are performed:
- **SSA Form:** The given assignments must be in single-assignment form; each symbol must be written at most once.
- **Independence of Accesses:** To avoid loop-carried dependencies, each field may be written at most once at
each index, and if a field is written at some location with index ``i``, it may only be read with index ``i`` in
the same location.
- **Independence of Writes:** A weaker requirement than access independence; each field may only be written once
at each index.
- **Dimension of index fields:** Index fields occuring in the kernel must have exactly one spatial dimension.
**Knowledge Collection**
The following knowledge is collected into the context:
- The set of fields accessed in the kernel
"""
FieldAndIndex = namedtuple("FieldAndIndex", ["field", "index"])
def __init__(
self,
ctx: KernelCreationContext,
check_access_independence: bool = True,
check_double_writes: bool = True,
):
self._ctx = ctx
self._check_access_independence = check_access_independence
self._check_double_writes = check_double_writes
# Map pairs of fields and indices to offsets
self._field_writes: dict[KernelAnalysis.FieldAndIndex, set[Any]] = defaultdict(
set
)
self._fields_written: set[Field] = set()
self._fields_read: set[Field] = set()
self._scopes = NestedScopes()
self._called = False
def __call__(
self, obj: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase
):
if self._called:
raise PsInternalCompilerError("KernelAnalysis called twice!")
self._called = True
self._visit(obj)
for field in chain(self._fields_written, self._fields_read):
self._ctx.add_field(field)
def _visit(self, obj: Any):
match obj:
case AssignmentCollection(main_asms, subexps):
self._visit(subexps)
self._visit(main_asms)
case [*asms]: # lists and tuples are unpacked
for asm in asms:
self._visit(asm)
case AssignmentBase():
self._handle_rhs(obj.rhs)
self._handle_lhs(obj.lhs)
case unknown:
raise KernelConstraintsError(
f"Don't know how to interpret {unknown} in a kernel."
)
def _handle_lhs(self, lhs: sp.Basic):
if not isinstance(lhs, sp.Symbol):
raise KernelConstraintsError(
f"Invalid expression on assignment left-hand side: {lhs}"
)
match lhs:
case Field.Access(field, offsets, index):
self._fields_written.add(field)
self._fields_read.update(lhs.indirect_addressing_fields)
fai = self.FieldAndIndex(field, index)
if self._check_double_writes and offsets in self._field_writes[fai]:
raise KernelConstraintsError(
f"Field {field.name} is written twice at the same location"
)
self._field_writes[fai].add(offsets)
if self._check_double_writes and len(self._field_writes[fai]) > 1:
raise KernelConstraintsError(
f"Field {field.name} is written at two different locations"
)
case sp.Symbol():
if self._scopes.is_defined_locally(lhs):
raise KernelConstraintsError(
f"Assignments not in SSA form, multiple assignments to {lhs.name}"
)
if lhs in self._scopes.free_parameters:
raise KernelConstraintsError(
f"Symbol {lhs.name} is written, after it has been read"
)
self._scopes.define_symbol(lhs)
def _handle_rhs(self, rhs: sp.Basic):
def rec(expr: sp.Basic):
match expr:
case Field.Access(field, offsets, index):
self._fields_read.add(field)
self._fields_read.update(expr.indirect_addressing_fields)
# TODO: Should we recurse into the arguments of the field access?
if self._check_access_independence:
writes = self._field_writes[
KernelAnalysis.FieldAndIndex(field, index)
]
assert len(writes) <= 1
for write_offset in writes:
if write_offset != offsets:
raise KernelConstraintsError(
f"Violation of loop independence condition. Field "
f"{field} is read at {offsets} and written at {write_offset}"
)
case sp.Symbol():
self._scopes.access_symbol(expr)
for arg in expr.args:
rec(arg)
rec(rhs)
class NestedScopes:
    """Model of symbol visibility based on a stack of nested scopes.

    - a symbol that is accessed without having been defined before is recorded as a "free parameter"
    - free parameters are global, i.e. they do not belong to any scope
    - `push` opens a new innermost scope, `pop` discards the innermost scope

    >>> s = NestedScopes()
    >>> s.access_symbol("a")
    >>> s.is_defined("a")
    False
    >>> s.free_parameters
    {'a'}
    >>> s.define_symbol("b")
    >>> s.is_defined("b")
    True
    >>> s.push()
    >>> s.is_defined_locally("b")
    False
    >>> s.define_symbol("c")
    >>> s.pop()
    >>> s.is_defined("c")
    False
    """

    def __init__(self):
        # Stack of scopes; the last entry is the innermost one
        self._defined = [set()]
        self.free_parameters = set()

    def access_symbol(self, symbol):
        """Register an access; symbols not defined in any scope become free parameters."""
        if self.is_defined(symbol):
            return
        self.free_parameters.add(symbol)

    def define_symbol(self, symbol):
        """Define the symbol in the innermost scope."""
        innermost = self._defined[-1]
        innermost.add(symbol)

    def is_defined(self, symbol):
        """Tell whether the symbol is defined in any of the currently open scopes."""
        for scope in self._defined:
            if symbol in scope:
                return True
        return False

    def is_defined_locally(self, symbol):
        """Tell whether the symbol is defined in the innermost scope."""
        innermost = self._defined[-1]
        return symbol in innermost

    def push(self):
        """Open a new, empty innermost scope."""
        self._defined.append(set())

    def pop(self):
        """Discard the innermost scope; at least one scope must remain afterward."""
        self._defined.pop()
        assert self.depth >= 1

    @property
    def depth(self):
        """Number of currently open scopes."""
        return len(self._defined)
from typing import Any, Sequence, cast, overload
import numpy as np
import sympy as sp
from sympy.codegen.ast import AssignmentBase
from ..ast import PsAstNode
from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr
from ..ast.structural import PsLoop, PsBlock, PsAssignment
from ..memory import PsSymbol
from ..constants import PsConstant
from .context import KernelCreationContext
from .freeze import FreezeExpressions, ExprLike
from .typification import Typifier
from .iteration_space import FullIterationSpace
# Union of the types accepted by `AstFactory.parse_index`, for use in annotations
IndexParsable = PsExpression | PsSymbol | PsConstant | sp.Expr | int | np.integer
# The same types as a tuple, used for the runtime `isinstance` checks
_IndexParsable = (PsExpression, PsSymbol, PsConstant, sp.Expr, int, np.integer)
class AstFactory:
"""Factory providing a convenient interface for building syntax trees.
The `AstFactory` uses the defaults provided by the given `KernelCreationContext` to quickly create
AST nodes. Depending on context (numerical, loop indexing, etc.), symbols and constants receive either
`ctx.default_dtype <KernelCreationContext.default_dtype>` or `ctx.index_dtype <KernelCreationContext.index_dtype>`.
Args:
ctx: The kernel creation context
"""
def __init__(self, ctx: KernelCreationContext):
    # The context provides `index_dtype` and the symbols used by the factory methods
    self._ctx = ctx
    # Freezing and typification helpers, both constructed from the same context
    self._freeze = FreezeExpressions(ctx)
    self._typify = Typifier(ctx)
@overload
def parse_sympy(self, sp_obj: sp.Symbol) -> PsSymbolExpr: ...

@overload
def parse_sympy(self, sp_obj: ExprLike) -> PsExpression: ...

@overload
def parse_sympy(self, sp_obj: AssignmentBase) -> PsAssignment: ...

def parse_sympy(self, sp_obj: ExprLike | AssignmentBase) -> PsAstNode:
    """Convert a SymPy expression or assignment to a typified AST node.

    The object is first passed through `FreezeExpressions` and then through the `Typifier`.
    Typification takes place in a numerical context, using the kernel
    creation context's `default_dtype <KernelCreationContext.default_dtype>`.

    Args:
        sp_obj: A SymPy expression or assignment
    """
    frozen = self._freeze(sp_obj)
    return self._typify(frozen)
@overload
def parse_index(self, idx: sp.Symbol | PsSymbol | PsSymbolExpr) -> PsSymbolExpr: ...

@overload
def parse_index(
    self, idx: int | np.integer | PsConstant | PsConstantExpr
) -> PsConstantExpr: ...

@overload
def parse_index(self, idx: sp.Expr | PsExpression) -> PsExpression: ...

def parse_index(self, idx: IndexParsable):
    """Parse the given object as an expression with data type
    `ctx.index_dtype <KernelCreationContext.index_dtype>`.

    Raises:
        TypeError: If the type of `idx` is not among the parsable types
    """
    if not isinstance(idx, _IndexParsable):
        raise TypeError(
            f"Cannot parse object of type {type(idx)} as an index expression"
        )

    # Bring the input into AST form first; typification is shared by these cases
    if isinstance(idx, PsExpression):
        expr = idx
    elif isinstance(idx, (PsSymbol, PsConstant)):
        expr = PsExpression.make(idx)
    elif isinstance(idx, sp.Expr):
        expr = self._freeze(idx)
    else:
        # Remaining case: plain integers, which directly become constants of the index type
        return PsExpression.make(PsConstant(idx, self._ctx.index_dtype))

    # Only the first entry of the typification result is returned
    return self._typify.typify_expression(expr, self._ctx.index_dtype)[0]
def _parse_any_index(self, idx: Any) -> PsExpression:
    """Forward `idx` to `parse_index` if its type is parsable, and raise `TypeError` otherwise."""
    if isinstance(idx, _IndexParsable):
        return self.parse_index(idx)

    raise TypeError(f"Cannot parse {idx} as an index expression")
def parse_slice(
    self,
    iter_slice: IndexParsable | slice,
    normalize_to: IndexParsable | None = None,
) -> tuple[PsExpression, PsExpression, PsExpression]:
    """Parse a slice to obtain start, stop and step expressions for a loop or iteration space dimension.

    The slice entries may be instances of `PsExpression`, `PsSymbol` or `PsConstant`, in which case they
    must typify with the kernel creation context's ``index_dtype``.
    They may also be sympy expressions or integer constants, in which case they are parsed to AST objects
    and must also typify with the kernel creation context's ``index_dtype``.

    The `step` member of the slice, if it is constant, must be positive.

    The slice may optionally be normalized with respect to an upper iteration limit.
    If ``normalize_to`` is specified, negative integers in ``iter_slice.start`` and ``iter_slice.stop`` will
    be added to that normalization limit.

    If ``iter_slice`` is not a slice but a single index ``i``, the result covers
    the range from ``i`` to ``i + 1`` with step one.

    Args:
        iter_slice: The iteration slice
        normalize_to: The upper iteration limit with respect to which the slice should be normalized

    Returns:
        The tuple ``(start, stop, step)`` of index expressions

    Raises:
        ValueError: If the slice's step is a constant that is not positive, or if the slice has
            no ``stop`` and no normalization limit is given
        TypeError: If an entry of the slice cannot be parsed as an index expression
    """
    from pystencils.backend.transformations import EliminateConstants

    # Applied to every sum built below
    fold = EliminateConstants(self._ctx)

    start: PsExpression
    stop: PsExpression | None
    step: PsExpression

    if not isinstance(iter_slice, slice):
        # A single index `i`: iterate over `i` up to `i + 1` with step one
        start = self.parse_index(iter_slice)
        stop = fold(
            self._typify(self.parse_index(iter_slice) + self.parse_index(1))
        )
        step = self.parse_index(1)

        if normalize_to is not None:
            upper_limit = self.parse_index(normalize_to)
            # Only negative constants are shifted by the limit;
            # `stop` derives from `start` and is therefore shifted along with it
            if isinstance(start, PsConstantExpr) and start.constant.value < 0:
                start = fold(self._typify(upper_limit.clone() + start))
                stop = fold(self._typify(upper_limit.clone() + stop))

    else:
        # Missing slice entries default to `start = 0` and `step = 1`;
        # a missing `stop` is resolved further below
        start = self._parse_any_index(
            iter_slice.start if iter_slice.start is not None else 0
        )
        stop = (
            self._parse_any_index(iter_slice.stop)
            if iter_slice.stop is not None
            else None
        )
        step = self._parse_any_index(
            iter_slice.step if iter_slice.step is not None else 1
        )

        # Only constant steps can be validated here
        if isinstance(step, PsConstantExpr) and step.constant.value <= 0:
            raise ValueError(
                f"Invalid value for `slice.step`: {step.constant.value}"
            )

        if normalize_to is not None:
            upper_limit = self.parse_index(normalize_to)
            if isinstance(start, PsConstantExpr) and start.constant.value < 0:
                start = fold(self._typify(upper_limit.clone() + start))

            if stop is None:
                # An open-ended slice runs up to the normalization limit
                stop = upper_limit
            elif isinstance(stop, PsConstantExpr) and stop.constant.value < 0:
                stop = fold(self._typify(upper_limit.clone() + stop))

        elif stop is None:
            raise ValueError(
                "Cannot parse a slice with `stop == None` if no normalization limit is given"
            )

    assert stop is not None  # for mypy

    return start, stop, step
def loop(self, ctr_name: str, iteration_slice: slice, body: PsBlock):
    """Build a `PsLoop` iterating over the region described by a slice.

    The counter symbol is obtained from the context under the given name,
    with the context's ``index_dtype``.

    Args:
        ctr_name: Name of the loop counter
        iteration_slice: The iteration region as a slice; see `parse_slice`.
        body: The loop body
    """
    ctr_symbol = self._ctx.get_symbol(ctr_name, self._ctx.index_dtype)
    ctr = PsExpression.make(ctr_symbol)

    # `parse_slice` yields the triple (start, stop, step)
    limits = self.parse_slice(iteration_slice)
    return PsLoop(ctr, *limits, body)
def loop_nest(
    self, counters: Sequence[str], slices: Sequence[slice], body: PsBlock
) -> PsLoop:
    """Build a nest of loops, one for each given slice.

    The first counter and slice belong to the outermost loop, the last ones to the innermost loop.

    **Example:**
    This snippet creates a 3D loop nest with ten iterations in each dimension::

        >>> from pystencils import make_slice
        >>> ctx = KernelCreationContext()
        >>> factory = AstFactory(ctx)
        >>> loop = factory.loop_nest(("i", "j", "k"), make_slice[:10,:10,:10], PsBlock([]))

    Args:
        counters: Sequence of names for the loop counters
        slices: Sequence of iteration slices; see also `parse_slice`
        body: The loop body

    Raises:
        ValueError: If no slice is given, or if `counters` and `slices` differ in length
    """
    if not slices:
        raise ValueError(
            "At least one slice must be specified to create a loop nest."
        )

    # The nest is assembled from the inside out, hence the reversed iteration
    nest = body
    for name, iter_slice in zip(reversed(counters), reversed(slices), strict=True):
        # A previously created loop must be wrapped in a block to serve as a loop body
        inner = nest if isinstance(nest, PsBlock) else PsBlock([nest])
        nest = self.loop(name, iter_slice, inner)

    return cast(PsLoop, nest)
def loops_from_ispace(
    self,
    ispace: FullIterationSpace,
    body: PsBlock,
    loop_order: Sequence[int] | None = None,
) -> PsLoop:
    """Build a loop nest that iterates over a dense iteration space.

    Each dimension of the iteration space contributes one loop,
    using the dimension's counter, start, stop and step.

    Args:
        ispace: The iteration space object
        body: The loop body
        loop_order: Optionally, a permutation of integers indicating the order of loops
    """
    dims = ispace.dimensions
    if loop_order is not None:
        dims = [dims[i] for i in loop_order]

    # Assemble the nest from the innermost loop outward
    nest: PsLoop | PsBlock = body
    for dim in dims[::-1]:
        # A previously created loop must be wrapped in a block to serve as a loop body
        loop_body = nest if isinstance(nest, PsBlock) else PsBlock([nest])
        nest = PsLoop(
            PsSymbolExpr(dim.counter),
            dim.start,
            dim.stop,
            dim.step,
            loop_body,
        )

    assert isinstance(nest, PsLoop)
    return nest