from __future__ import annotations

from typing import Iterable, Iterator, Any
from itertools import chain, count
from collections import namedtuple, defaultdict
import re

from ...defaults import DEFAULTS
from ...field import Field, FieldType
from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType

from ..memory import PsSymbol, PsBuffer
from ..constants import PsConstant
from ...types import (
    PsType,
    PsIntegerType,
    PsNumericType,
    PsPointerType,
    deconstify,
)

from ..exceptions import PsInternalCompilerError, KernelConstraintsError
from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace


class FieldsInKernel:
    def __init__(self) -> None:
        self.domain_fields: set[Field] = set()
        self.index_fields: set[Field] = set()
        self.custom_fields: set[Field] = set()
        self.buffer_fields: set[Field] = set()

        self.archetype_field: Field | None = None

    def __iter__(self) -> Iterator:
        return chain(
            self.domain_fields,
            self.index_fields,
            self.custom_fields,
            self.buffer_fields,
        )


FieldArrayPair = namedtuple("FieldArrayPair", ("field", "array"))


class KernelCreationContext:
    """Manages the translation process from the SymPy frontend to the backend AST, and collects
    all necessary information for the translation:

    - *Data Types*: The kernel creation context manages the default data types for loop limits
      and counters, index calculations, and the typifier.
    - *Symbols*: The context maintains a symbol table, keeping track of all symbols encountered
      during kernel translation together with their types.
    - *Fields and Arrays*: The context collects all fields encountered during code generation,
      applies a few consistency checks to them, and manages their associated arrays.
    - *Iteration Space*: The context manages the iteration space of the kernel currently being
      translated.
    - *Constraints*: The context collects all kernel parameter constraints introduced during the
      translation process.
    - *Required Headers*: The context collects all header files required for the kernel to run.
    """

    def __init__(
        self,
        default_dtype: PsNumericType = DEFAULTS.numeric_dtype,
        index_dtype: PsIntegerType = DEFAULTS.index_dtype,
    ):
        self._default_dtype = deconstify(default_dtype)
        self._index_dtype = deconstify(index_dtype)

        self._symbols: dict[str, PsSymbol] = dict()
        self._symbol_ctr_pattern = re.compile(r"__[0-9]+$")
        self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0)

        self._fields_and_arrays: dict[str, FieldArrayPair] = dict()
        self._fields_collection = FieldsInKernel()

        self._ispace: IterationSpace | None = None

        self._req_headers: set[str] = set()
        self._metadata: dict[str, Any] = dict()

    @property
    def default_dtype(self) -> PsNumericType:
        """Data type used by default for numerical expressions"""
        return self._default_dtype

    @property
    def index_dtype(self) -> PsIntegerType:
        """Data type used by default for index expressions"""
        return self._index_dtype

    def resolve_dynamic_type(self, dtype: DynamicType | PsType) -> PsType:
        """Selects the appropriate data type for `DynamicType` instances, and returns all other types as they are."""
        match dtype:
            case DynamicType.NUMERIC_TYPE:
                return self._default_dtype
            case DynamicType.INDEX_TYPE:
                return self._index_dtype
            case _:
                return dtype

    @property
    def metadata(self) -> dict[str, Any]:
        return self._metadata

    #   Symbols

    def get_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol:
        """Retrieve the symbol with the given name and data type from the symbol table.

        If no symbol named ``name`` exists, a new symbol with the given data type is created.

        If a symbol with the given ``name`` already exists and ``dtype`` is not `None`,
        the given data type will be applied to it, and it is returned.
        If the symbol already has a different data type, an error will be raised.

        If the symbol already exists and ``dtype`` is `None`, the existing symbol is returned
        without checking or altering its data type.

        Args:
            name: The symbol's name
            dtype: The symbol's data type, or `None`
        """
        if name not in self._symbols:
            symb = PsSymbol(name, None)
            self._symbols[name] = symb
        else:
            symb = self._symbols[name]

        if dtype is not None:
            symb.apply_dtype(dtype)

        return symb
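
    #   Illustrative note (added for documentation, not in the original source):
    #   repeated lookups return the same object, and conflicting dtypes raise,
    #   per the docstring above. Here, ``some_other_dtype`` is a hypothetical
    #   type that differs from the index dtype:
    #
    #       i = ctx.get_symbol("i", ctx.index_dtype)
    #       assert ctx.get_symbol("i") is i
    #       ctx.get_symbol("i", some_other_dtype)   # raises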

    def get_new_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol:
        """Always create a new symbol, deduplicating its name if another symbol with the same name already exists."""
        if name in self._symbols:
            return self.duplicate_symbol(self._symbols[name], dtype)
        else:
            return self.get_symbol(name, dtype)

    def find_symbol(self, name: str) -> PsSymbol | None:
        """Find a symbol with the given name in the symbol table, if it exists.

        Returns:
            The symbol with the given name, or `None` if no such symbol exists.
        """
        return self._symbols.get(name, None)

    def add_symbol(self, symbol: PsSymbol):
        """Add an existing symbol to the symbol table.

        If a symbol with the same name already exists, an error will be raised.
        """
        if symbol.name in self._symbols:
            raise PsInternalCompilerError(f"Duplicate symbol: {symbol.name}")

        self._symbols[symbol.name] = symbol

    def replace_symbol(self, old: PsSymbol, new: PsSymbol):
        """Replace one symbol by another.

        The two symbols ``old`` and ``new`` must have the same name, but may have different data types.
        """
        if old.name != new.name:
            raise PsInternalCompilerError(
                "replace_symbol: Old and new symbol must have the same name"
            )

        if old.name not in self._symbols:
            raise PsInternalCompilerError("Trying to replace an unknown symbol")

        self._symbols[old.name] = new

    def duplicate_symbol(
        self, symb: PsSymbol, new_dtype: PsType | None = None
    ) -> PsSymbol:
        """Canonically duplicates the given symbol.

        A new symbol with the new name ``symb.name + "__<counter>"`` and optionally a different data type
        is created, added to the symbol table, and returned.
        The ``counter`` reflects the number of previously created duplicates of this symbol.
        """
        basename = self.basename(symb)

        if new_dtype is None:
            new_dtype = symb.dtype

        initial_count = self._symbol_dup_table[basename]
        for i in count(initial_count):
            dup_name = f"{basename}__{i}"
            if self.find_symbol(dup_name) is None:
                self._symbol_dup_table[basename] = i + 1
                return self.get_symbol(dup_name, new_dtype)

        assert False, "unreachable code"

    def basename(self, symb: PsSymbol) -> str:
        """Returns the original name a symbol had before duplication."""
        if (result := self._symbol_ctr_pattern.search(symb.name)) is not None:
            span = result.span()
            return symb.name[: span[0]]
        else:
            return symb.name

    @property
    def symbols(self) -> Iterable[PsSymbol]:
        """Return an iterable of all symbols listed in the symbol table."""
        return self._symbols.values()

    #   Fields and Arrays

    @property
    def fields(self) -> FieldsInKernel:
        """Collection of fields that occurred during the current kernel translation."""
        return self._fields_collection

    def add_field(self, field: Field):
        """Add the given field to the context's fields collection.

        This method adds the passed ``field`` to the context's field collection, which is
        accessible through the `fields <KernelCreationContext.fields>` member,
        and creates the underlying buffer for the field,
        which is retrievable through `get_buffer`.
        Before adding the field to the collection, various sanity and constraint checks are applied.
        """
        if field.name in self._fields_and_arrays:
            existing_field = self._fields_and_arrays[field.name].field
            if existing_field != field:
                raise KernelConstraintsError(
                    "Encountered two fields with the same name, but different properties: "
                    f"{field} and {existing_field}"
                )
            else:
                return

        #   Check field constraints, create buffer, and add them to the collection
        match field.field_type:
            case FieldType.GENERIC | FieldType.STAGGERED | FieldType.STAGGERED_FLUX:
                buf = self._create_regular_field_buffer(field)
                self._fields_collection.domain_fields.add(field)

            case FieldType.BUFFER:
                buf = self._create_buffer_field_buffer(field)
                self._fields_collection.buffer_fields.add(field)

            case FieldType.INDEXED:
                if field.spatial_dimensions != 1:
                    raise KernelConstraintsError(
                        f"Invalid spatial shape of index field {field.name}: {field.spatial_dimensions}. "
                        "Index fields must be one-dimensional."
                    )
                buf = self._create_regular_field_buffer(field)
                self._fields_collection.index_fields.add(field)

            case FieldType.CUSTOM:
                buf = self._create_regular_field_buffer(field)
                self._fields_collection.custom_fields.add(field)

            case _:
                assert False, "unreachable code"

        self._fields_and_arrays[field.name] = FieldArrayPair(field, buf)

    @property
    def arrays(self) -> Iterable[PsBuffer]:
        yield from (item.array for item in self._fields_and_arrays.values())

    def get_buffer(self, field: Field) -> PsBuffer:
        """Retrieve the underlying array for a given field.

        If the given field was not previously registered using `add_field`,
        this method internally calls `add_field` to check the field for consistency.
        """
        if field.name in self._fields_and_arrays:
            if field != self._fields_and_arrays[field.name].field:
                raise KernelConstraintsError(
                    "Encountered two fields of the same name but with different properties."
                )
        else:
            self.add_field(field)
        return self._fields_and_arrays[field.name].array

    def find_field(self, name: str) -> Field:
        return self._fields_and_arrays[name].field

    #   Iteration Space

    def set_iteration_space(self, ispace: IterationSpace):
        """Set the iteration space used for the current kernel."""
        self._ispace = ispace

    def get_iteration_space(self) -> IterationSpace:
        if self._ispace is None:
            raise PsInternalCompilerError("No iteration space set in context.")
        return self._ispace

    def get_full_iteration_space(self) -> FullIterationSpace:
        if not isinstance(self._ispace, FullIterationSpace):
            raise PsInternalCompilerError("No full iteration space set in context.")
        return self._ispace

    def get_sparse_iteration_space(self) -> SparseIterationSpace:
        if not isinstance(self._ispace, SparseIterationSpace):
            raise PsInternalCompilerError("No sparse iteration space set in context.")
        return self._ispace

    #   Headers

    @property
    def required_headers(self) -> set[str]:
        return self._req_headers

    def require_header(self, header: str):
        self._req_headers.add(header)

    #   ----------- Internals ---------------------------------------------------------------------

    def _normalize_type(self, s: TypedSymbol) -> PsIntegerType:
        match s.dtype:
            case DynamicType.INDEX_TYPE:
                return self.index_dtype
            case DynamicType.NUMERIC_TYPE:
                if isinstance(self.default_dtype, PsIntegerType):
                    return self.default_dtype
                else:
                    raise KernelConstraintsError(
                        f"Cannot use non-integer default numeric type {self.default_dtype} "
                        f"in field indexing symbol {s}."
                    )
            case PsIntegerType():
                return deconstify(s.dtype)
            case _:
                raise KernelConstraintsError(
                    f"Invalid data type for field indexing symbol {s}: {s.dtype}"
                )

    def _create_regular_field_buffer(self, field: Field) -> PsBuffer:
        idx_types = set(
            self._normalize_type(s)
            for s in chain(field.shape, field.strides)
            if isinstance(s, TypedSymbol)
        )

        entry_type = self.resolve_dynamic_type(field.dtype)

        if len(idx_types) > 1:
            raise KernelConstraintsError(
                f"Multiple incompatible types found in index symbols of field {field}: "
                f"{idx_types}"
            )

        idx_type = idx_types.pop() if len(idx_types) > 0 else self.index_dtype

        def convert_size(s: TypedSymbol | int) -> PsSymbol | PsConstant:
            if isinstance(s, TypedSymbol):
                return self.get_symbol(s.name, idx_type)
            else:
                return PsConstant(s, idx_type)

        buf_shape = [convert_size(s) for s in field.shape]
        buf_strides = [convert_size(s) for s in field.strides]

        #   The frontend doesn't quite agree with itself on how to model
        #   fields with trivial index dimensions. Sometimes the index_shape is empty,
        #   sometimes it is (1,). This is canonicalized here.
        if not field.index_shape:
            buf_shape += [convert_size(1)]
            buf_strides += [convert_size(1)]

        from ...codegen.properties import FieldShape, FieldStride

        for i, size in enumerate(buf_shape):
            if isinstance(size, PsSymbol):
                size.add_property(FieldShape(field, i))

        for i, stride in enumerate(buf_strides):
            if isinstance(stride, PsSymbol):
                stride.add_property(FieldStride(field, i))

        base_ptr = self.get_symbol(
            DEFAULTS.field_pointer_name(field.name),
            PsPointerType(entry_type, restrict=True),
        )

        return PsBuffer(field.name, entry_type, base_ptr, buf_shape, buf_strides)

    def _create_buffer_field_buffer(self, field: Field) -> PsBuffer:
        if field.spatial_dimensions != 1:
            raise KernelConstraintsError(
                f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. "
                "Buffer fields must be one-dimensional."
            )

        if field.index_dimensions > 1:
            raise KernelConstraintsError(
                f"Invalid index shape of buffer field {field.name}: {field.index_dimensions}. "
                "Buffer fields can have at most one index dimension."
            )

        num_entries = field.index_shape[0] if field.index_shape else 1
        if not isinstance(num_entries, int):
            raise KernelConstraintsError(
                f"Invalid index shape of buffer field {field.name}: {num_entries}. "
                "Buffer fields cannot have variable index shape."
            )

        buffer_len = field.spatial_shape[0]
        buf_shape: list[PsSymbol | PsConstant]

        if isinstance(buffer_len, TypedSymbol):
            from ...codegen.properties import FieldShape

            idx_type = self._normalize_type(buffer_len)
            len_symb = self.get_symbol(buffer_len.name, idx_type)
            len_symb.add_property(FieldShape(field, 0))
            buf_shape = [len_symb, PsConstant(num_entries, idx_type)]
        else:
            idx_type = DEFAULTS.index_dtype
            buf_shape = [
                PsConstant(buffer_len, idx_type),
                PsConstant(num_entries, idx_type),
            ]

        buf_strides = [PsConstant(num_entries, idx_type), PsConstant(1, idx_type)]
        buf_dtype = self.resolve_dynamic_type(field.dtype)

        base_ptr = self.get_symbol(
            DEFAULTS.field_pointer_name(field.name),
            PsPointerType(buf_dtype, restrict=True),
        )

        return PsBuffer(field.name, buf_dtype, base_ptr, buf_shape, buf_strides)
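

#   Illustrative usage sketch (added for documentation; not part of the
#   original module). The function name is hypothetical; it exercises only
#   the symbol-duplication API defined above.
def _demo_symbol_duplication() -> None:
    ctx = KernelCreationContext()

    #   The first request creates the symbol
    x = ctx.get_symbol("x", ctx.default_dtype)

    #   `get_new_symbol` deduplicates by appending a "__<counter>" suffix
    x0 = ctx.get_new_symbol("x")
    assert x0.name == "x__0"

    #   `basename` strips the counter suffix again
    assert ctx.basename(x0) == "x"
    assert x is not x0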


from typing import overload, cast, Any
from functools import reduce
from operator import add, mul, sub, truediv

import sympy as sp
import sympy.core.relational
import sympy.logic.boolalg
from sympy.codegen.ast import AssignmentBase, AugmentedAssignment

from ...assignment import Assignment
from ...simp import AssignmentCollection
from ...sympyextensions import (
    integer_functions,
    ConditionalFieldAccess,
)
from ...sympyextensions.typed_sympy import TypedSymbol, TypeCast, DynamicType
from ...sympyextensions.pointers import AddressOf, mem_acc
from ...field import Field, FieldType

from .context import KernelCreationContext

from ..ast.structural import (
    PsAstNode,
    PsBlock,
    PsAssignment,
    PsDeclaration,
    PsExpression,
    PsSymbolExpr,
)
from ..ast.expressions import (
    PsBufferAcc,
    PsArrayInitList,
    PsBitwiseAnd,
    PsBitwiseOr,
    PsBitwiseXor,
    PsAddressOf,
    PsCall,
    PsCast,
    PsConstantExpr,
    PsIntDiv,
    PsRem,
    PsLeftShift,
    PsLookup,
    PsRightShift,
    PsSubscript,
    PsTernary,
    PsRel,
    PsEq,
    PsNe,
    PsLt,
    PsGt,
    PsLe,
    PsGe,
    PsAnd,
    PsOr,
    PsNot,
    PsMemAcc,
)
from ..ast.vector import PsVecMemAcc
from ..constants import PsConstant
from ...types import PsNumericType, PsStructType, PsType
from ..exceptions import PsInputError
from ..functions import PsMathFunction, MathFunctions
from ..exceptions import FreezeError


ExprLike = (
    sp.Expr
    | sp.Tuple
    | sympy.core.relational.Relational
    | sympy.logic.boolalg.BooleanFunction
)
_ExprLike = (
    sp.Expr,
    sp.Tuple,
    sympy.core.relational.Relational,
    sympy.logic.boolalg.BooleanFunction,
)


class FreezeExpressions:
    """Convert expressions and kernels expressed in the SymPy language to the code generator's internal representation.

    This class accepts a subset of the SymPy symbolic algebra language complemented with the extensions
    implemented in `pystencils.sympyextensions`, and converts it to the abstract syntax tree representation
    of the pystencils code generator. It is invoked early during the code generation process.

    TODO: Document the full set of supported SymPy features, with restrictions and caveats
    TODO: Properly document the SymPy extensions provided by pystencils
    """

    def __init__(self, ctx: KernelCreationContext):
        self._ctx = ctx

    @overload
    def __call__(self, obj: AssignmentCollection) -> PsBlock:
        pass

    @overload
    def __call__(self, obj: ExprLike) -> PsExpression:
        pass

    @overload
    def __call__(self, obj: AssignmentBase) -> PsAssignment:
        pass

    def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode:
        if isinstance(obj, AssignmentCollection):
            return PsBlock([self.visit(asm) for asm in obj.all_assignments])
        elif isinstance(obj, AssignmentBase):
            return cast(PsAssignment, self.visit(obj))
        elif isinstance(obj, _ExprLike):
            return cast(PsExpression, self.visit(obj))
        else:
            raise PsInputError(f"Don't know how to freeze {obj}")

    def visit(self, node: sp.Basic) -> PsAstNode:
        mro = list(type(node).__mro__)
        while mro:
            method_name = "map_" + mro.pop(0).__name__
            try:
                method = self.__getattribute__(method_name)
            except AttributeError:
                pass
            else:
                return method(node)
        raise FreezeError(f"Don't know how to freeze expression {node}")

    def visit_expr_or_builtin(self, obj: Any) -> PsExpression:
        if isinstance(obj, _ExprLike):
            return self.visit_expr(obj)
        elif isinstance(obj, (int, float, bool)):
            return PsExpression.make(PsConstant(obj))
        else:
            raise FreezeError(f"Don't know how to freeze {obj}")

    def visit_expr(self, expr: sp.Basic):
        if not isinstance(expr, _ExprLike):
            raise FreezeError(f"Cannot freeze {expr} to an expression")
        return cast(PsExpression, self.visit(expr))

    def freeze_expression(self, expr: sp.Expr) -> PsExpression:
        return cast(PsExpression, self.visit(expr))

    def map_Assignment(self, expr: Assignment):
        lhs = self.visit(expr.lhs)
        rhs = self.visit(expr.rhs)

        assert isinstance(lhs, PsExpression)
        assert isinstance(rhs, PsExpression)

        if isinstance(lhs, PsSymbolExpr):
            return PsDeclaration(lhs, rhs)
        elif isinstance(lhs, (PsBufferAcc, PsLookup, PsVecMemAcc)):
            return PsAssignment(lhs, rhs)
        else:
            raise FreezeError(
                f"Encountered unsupported expression on assignment left-hand side: {lhs}"
            )

    def map_AugmentedAssignment(self, expr: AugmentedAssignment):
        lhs = self.visit(expr.lhs)
        rhs = self.visit(expr.rhs)

        assert isinstance(lhs, PsExpression)
        assert isinstance(rhs, PsExpression)

        match expr.op:
            case "+=":
                op = add
            case "-=":
                op = sub
            case "*=":
                op = mul
            case "/=":
                op = truediv
            case _:
                raise FreezeError(f"Unsupported augmented assignment: {expr.op}.")

        return PsAssignment(lhs, op(lhs.clone(), rhs))

    def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
        symb = self._ctx.get_symbol(spsym.name)
        return PsSymbolExpr(symb)

    def map_Add(self, expr: sp.Add) -> PsExpression:
        #   TODO: think about numerically sensible ways of freezing sums and products
        frozen_expr = self.visit_expr(expr.args[0])

        for summand in expr.args[1:]:
            if isinstance(summand, sp.Mul) and any(
                factor == -1 for factor in summand.args
            ):
                summand = -summand
                op = sub
            else:
                op = add

            frozen_expr = op(frozen_expr, self.visit_expr(summand))

        return frozen_expr

    def map_Mul(self, expr: sp.Mul) -> PsExpression:
        return reduce(mul, (self.visit_expr(arg) for arg in expr.args))

    def map_Pow(self, expr: sp.Pow) -> PsExpression:
        base = expr.args[0]
        exponent = expr.args[1]

        expr_frozen = self.visit_expr(base)

        if isinstance(exponent, sp.Rational):
            #   Decompose rational exponent
            num: int = exponent.numerator
            denom: int = exponent.denominator

            if denom <= 2 and abs(num) <= 8:
                #   At most a square root, and at most eight factors

                reciprocal = False

                if num < 0:
                    reciprocal = True
                    num = -num

                if denom == 2:
                    expr_frozen = PsMathFunction(MathFunctions.Sqrt)(expr_frozen)
                    denom = 1

                assert denom == 1

                #   Pairwise multiplication for logarithmic runtime
                factors = [expr_frozen] + [expr_frozen.clone() for _ in range(num - 1)]
                while len(factors) > 1:
                    combined = [x * y for x, y in zip(factors[::2], factors[1::2])]
                    if len(factors) % 2 == 1:
                        combined.append(factors[-1])
                    factors = combined

                expr_frozen = factors.pop()

                if reciprocal:
                    one = PsExpression.make(PsConstant(1))
                    expr_frozen = one / expr_frozen

                return expr_frozen

        #   If we got this far, use pow
        exponent_frozen = self.visit_expr(exponent)
        expr_frozen = PsMathFunction(MathFunctions.Pow)(expr_frozen, exponent_frozen)

        return expr_frozen
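
    #   Illustrative note (added for documentation, not in the original source):
    #   by the decomposition above, x**(3/2) freezes to sqrt(x)*sqrt(x)*sqrt(x)
    #   (multiplied pairwise), and x**(-2) freezes to 1 / (x*x); any exponent
    #   outside the fast-path bounds falls through to the Pow math function.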

    def map_Integer(self, expr: sp.Integer) -> PsConstantExpr:
        value = int(expr)
        return PsConstantExpr(PsConstant(value))

    def map_Float(self, expr: sp.Float) -> PsConstantExpr:
        value = float(expr)  # TODO: check accuracy of evaluation
        return PsConstantExpr(PsConstant(value))

    def map_Rational(self, expr: sp.Rational) -> PsExpression:
        num = PsConstantExpr(PsConstant(expr.numerator))
        denom = PsConstantExpr(PsConstant(expr.denominator))
        return num / denom

    def map_TypedSymbol(self, expr: TypedSymbol):
        dtype = self._ctx.resolve_dynamic_type(expr.dtype)
        symb = self._ctx.get_symbol(expr.name, dtype)
        return PsSymbolExpr(symb)

    def map_Tuple(self, expr: sp.Tuple) -> PsArrayInitList:
        if not expr:
            raise FreezeError("Cannot translate an empty tuple.")

        items = [self.visit_expr(item) for item in expr]

        if any(isinstance(i, PsArrayInitList) for i in items):
            #   Recursive case: nested arrays
            if not all(isinstance(i, PsArrayInitList) for i in items):
                raise FreezeError(
                    f"Cannot translate nested arrays of non-uniform shape: {expr}"
                )

            subarrays = cast(list[PsArrayInitList], items)
            shape_tail = subarrays[0].shape

            if not all(s.shape == shape_tail for s in subarrays[1:]):
                raise FreezeError(
                    f"Cannot translate nested arrays of non-uniform shape: {expr}"
                )

            return PsArrayInitList([s.items_grid for s in subarrays])  # type: ignore
        else:
            #   Base case: no nested arrays
            return PsArrayInitList(items)

    def map_Indexed(self, expr: sp.Indexed) -> PsSubscript:
        assert isinstance(expr.base, sp.IndexedBase)
        base = self.visit_expr(expr.base.label)
        indices = [self.visit_expr(i) for i in expr.indices]
        return PsSubscript(base, indices)

    def map_Access(self, access: Field.Access):
        field = access.field
        array = self._ctx.get_buffer(field)
        ptr = array.base_pointer

        offsets: list[PsExpression] = [
            self.visit_expr_or_builtin(o) for o in access.offsets
        ]
        indices: list[PsExpression]

        if not access.is_absolute_access:
            match field.field_type:
                case FieldType.GENERIC | FieldType.CUSTOM:
                    #   Add the iteration counters
                    offsets = [
                        PsExpression.make(i) + o
                        for i, o in zip(
                            self._ctx.get_iteration_space().spatial_indices, offsets
                        )
                    ]
                case FieldType.INDEXED:
                    sparse_ispace = self._ctx.get_sparse_iteration_space()
                    #   Add sparse iteration counter to offset
                    assert len(offsets) == 1  # must have been checked by the context
                    offsets = [
                        offsets[0] + PsExpression.make(sparse_ispace.sparse_counter)
                    ]
                case FieldType.BUFFER:
                    ispace = self._ctx.get_full_iteration_space()
                    compressed_ctr = ispace.compressed_counter()
                    assert len(offsets) == 1
                    offsets = [compressed_ctr + offsets[0]]
                case unknown:
                    raise NotImplementedError(
                        f"Cannot translate accesses to field type {unknown} yet."
                    )

        #   If the array type is a struct, accesses are modelled using strings
        if isinstance(array.element_type, PsStructType):
            if isinstance(access.index, str):
                struct_member_name = access.index
                indices = [PsExpression.make(PsConstant(0))]
            elif len(access.index) == 1 and isinstance(access.index[0], str):
                struct_member_name = access.index[0]
                indices = [PsExpression.make(PsConstant(0))]
            else:
                raise FreezeError(
                    f"Unsupported access into field with struct-type elements: {access}"
                )
        else:
            struct_member_name = None
            indices = [self.visit_expr_or_builtin(i) for i in access.index]
            if not indices:
                #   For canonical representation, there must always be at least one index dimension
                indices = [PsExpression.make(PsConstant(0))]

        if struct_member_name is not None:
            #   Produce a Lookup here, don't check yet if the member name is valid. That's the typifier's job.
            return PsLookup(PsBufferAcc(ptr, offsets + indices), struct_member_name)
        else:
            return PsBufferAcc(ptr, offsets + indices)
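
    #   Illustrative note (added for documentation, not in the original source):
    #   for a relative access such as f[1, -1] into a GENERIC field, the branch
    #   above adds the spatial counters to the user-given offsets, so the result
    #   is a buffer access at (ctr_0 + 1, ctr_1 - 1, 0), where the trailing zero
    #   is the canonical trivial index dimension and the counter names are the
    #   defaults from DEFAULTS.spatial_counter_names.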

    def map_ConditionalFieldAccess(self, acc: ConditionalFieldAccess):
        facc = self.visit_expr(acc.access)
        condition = self.visit_expr(acc.outofbounds_condition)
        fallback = self.visit_expr(acc.outofbounds_value)
        return PsTernary(condition, fallback, facc)

    def map_Function(self, func: sp.Function) -> PsExpression:
        """Map SymPy function calls by mapping sympy function classes to backend-supported function symbols.

        If applicable, functions are mapped to binary operators, e.g. `PsBitwiseXor`.
        Other SymPy functions are frozen to an instance of `PsFunction`.
        """
        args = tuple(self.visit_expr(arg) for arg in func.args)

        match func:
            case sp.Abs():
                return PsCall(PsMathFunction(MathFunctions.Abs), args)
            case sp.floor():
                return PsCall(PsMathFunction(MathFunctions.Floor), args)
            case sp.ceiling():
                return PsCall(PsMathFunction(MathFunctions.Ceil), args)
            case sp.exp():
                return PsCall(PsMathFunction(MathFunctions.Exp), args)
            case sp.log():
                return PsCall(PsMathFunction(MathFunctions.Log), args)
            case sp.sin():
                return PsCall(PsMathFunction(MathFunctions.Sin), args)
            case sp.cos():
                return PsCall(PsMathFunction(MathFunctions.Cos), args)
            case sp.tan():
                return PsCall(PsMathFunction(MathFunctions.Tan), args)
            case sp.sinh():
                return PsCall(PsMathFunction(MathFunctions.Sinh), args)
            case sp.cosh():
                return PsCall(PsMathFunction(MathFunctions.Cosh), args)
            case sp.asin():
                return PsCall(PsMathFunction(MathFunctions.ASin), args)
            case sp.acos():
                return PsCall(PsMathFunction(MathFunctions.ACos), args)
            case sp.atan():
                return PsCall(PsMathFunction(MathFunctions.ATan), args)
            case sp.atan2():
                return PsCall(PsMathFunction(MathFunctions.ATan2), args)
            case integer_functions.int_div():
                return PsIntDiv(*args)
            case integer_functions.int_rem():
                return PsRem(*args)
            case integer_functions.bit_shift_left():
                return PsLeftShift(*args)
            case integer_functions.bit_shift_right():
                return PsRightShift(*args)
            case integer_functions.bitwise_and():
                return PsBitwiseAnd(*args)
            case integer_functions.bitwise_xor():
                return PsBitwiseXor(*args)
            case integer_functions.bitwise_or():
                return PsBitwiseOr(*args)
            case integer_functions.int_power_of_2():
                return PsLeftShift(PsExpression.make(PsConstant(1)), args[0])
            case integer_functions.round_to_multiple_towards_zero():
                return PsIntDiv(args[0], args[1]) * args[1]
            case integer_functions.ceil_to_multiple():
                return (
                    PsIntDiv(
                        args[0] + args[1] - PsExpression.make(PsConstant(1)), args[1]
                    )
                    * args[1]
                )
            case integer_functions.div_ceil():
                return PsIntDiv(
                    args[0] + args[1] - PsExpression.make(PsConstant(1)), args[1]
                )
            case AddressOf():
                return PsAddressOf(*args)
            case mem_acc():
                return PsMemAcc(*args)
            case _:
                raise FreezeError(f"Unsupported function: {func}")

    def map_Piecewise(self, expr: sp.Piecewise) -> PsTernary:
        from sympy.functions.elementary.piecewise import ExprCondPair

        cases: list[ExprCondPair] = cast(list[ExprCondPair], expr.args)

        if cases[-1].cond != sp.true:
            raise FreezeError(
                "The last case of a `Piecewise` must be the fallback case; its condition must always be `True`."
            )

        conditions = [self.visit_expr(c.cond) for c in cases[:-1]]
        subexprs = [self.visit_expr(c.expr) for c in cases]

        last_expr = subexprs.pop()
        ternary = PsTernary(conditions.pop(), subexprs.pop(), last_expr)

        while conditions:
            ternary = PsTernary(conditions.pop(), subexprs.pop(), ternary)

        return ternary

    def map_Min(self, expr: sp.Min) -> PsCall:
        return self._minmax(expr, PsMathFunction(MathFunctions.Min))

    def map_Max(self, expr: sp.Max) -> PsCall:
        return self._minmax(expr, PsMathFunction(MathFunctions.Max))

    def _minmax(self, expr: sp.Min | sp.Max, func: PsMathFunction) -> PsCall:
        args = [self.visit_expr(arg) for arg in expr.args]
        while len(args) > 1:
            args = [
                (PsCall(func, (args[i], args[i + 1])) if i + 1 < len(args) else args[i])
                for i in range(0, len(args), 2)
            ]
        return cast(PsCall, args[0])

    def map_TypeCast(self, cast_expr: TypeCast) -> PsCast | PsConstantExpr:
        dtype: PsType
        match cast_expr.dtype:
            case DynamicType.NUMERIC_TYPE:
                dtype = self._ctx.default_dtype
            case DynamicType.INDEX_TYPE:
                dtype = self._ctx.index_dtype
            case other if isinstance(other, PsType):
                dtype = other

        arg = self.visit_expr(cast_expr.expr)

        if (
            isinstance(arg, PsConstantExpr)
            and arg.constant.dtype is None
            and isinstance(dtype, PsNumericType)
        ):
            #   As of now, the typifier cannot infer the type of a bare constant.
            #   However, untyped constants may not appear in ASTs from which
            #   kernel functions are generated. Therefore, we annotate constants
            #   instead of casting them.
            return PsConstantExpr(arg.constant.interpret_as(dtype))
        else:
            return PsCast(dtype, arg)

    def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel:
        arg1, arg2 = [self.visit_expr(arg) for arg in rel.args]
        match rel.rel_op:  # type: ignore
            case "==":
                return PsEq(arg1, arg2)
            case "!=":
                return PsNe(arg1, arg2)
            case ">=":
                return PsGe(arg1, arg2)
            case "<=":
                return PsLe(arg1, arg2)
            case ">":
                return PsGt(arg1, arg2)
            case "<":
                return PsLt(arg1, arg2)
            case other:
                raise FreezeError(f"Unsupported relation: {other}")

    def map_And(self, conj: sympy.logic.And) -> PsAnd:
        args = [self.visit_expr(arg) for arg in conj.args]
        return reduce(PsAnd, args)  # type: ignore

    def map_Or(self, disj: sympy.logic.Or) -> PsOr:
        args = [self.visit_expr(arg) for arg in disj.args]
        return reduce(PsOr, args)  # type: ignore

    def map_Not(self, neg: sympy.logic.Not) -> PsNot:
        arg = self.visit_expr(neg.args[0])
        return PsNot(arg)
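

#   Illustrative usage sketch (added for documentation; not part of the
#   original module). The function name is hypothetical; freezing an
#   assignment whose left-hand side is a plain symbol yields a PsDeclaration,
#   as implemented in map_Assignment above.
def _demo_freeze_assignment() -> None:
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    x, y = sp.symbols("x, y")
    frozen = freeze(Assignment(x, y + 3))
    assert isinstance(frozen, PsDeclaration)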


from __future__ import annotations

from typing import Sequence, TYPE_CHECKING
from abc import ABC
from dataclasses import dataclass
from functools import reduce
from operator import mul

from ...defaults import DEFAULTS
from ...simp import AssignmentCollection
from ...field import Field, FieldType

from ..memory import PsSymbol, PsBuffer
from ..constants import PsConstant
from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem
from ..ast.util import failing_cast
from ...types import PsStructType
from ..exceptions import PsInputError, KernelConstraintsError

if TYPE_CHECKING:
    from .context import KernelCreationContext


class IterationSpace(ABC):
    """Represents the n-dimensional iteration space of a pystencils kernel.

    Instances of this class represent the kernel's iteration region during translation from
    SymPy, before any indexing sources are generated. It provides the counter symbols which
    should be used to translate field accesses to array accesses.

    There are two types of iteration spaces, modelled by subclasses:
    - The full iteration space translates to an n-dimensional loop nest or the corresponding device
      indexing scheme.
    - The sparse iteration space translates to a single loop over an index list which in turn provides the
      spatial indices.
    """

    def __init__(self, spatial_indices: Sequence[PsSymbol]):
        if len(spatial_indices) == 0:
            raise ValueError("Iteration space must be at least one-dimensional.")

        self._spatial_indices = tuple(spatial_indices)

    @property
    def spatial_indices(self) -> tuple[PsSymbol, ...]:
        return self._spatial_indices

    @property
    def rank(self) -> int:
        return len(self._spatial_indices)


class FullIterationSpace(IterationSpace):
    """N-dimensional full iteration space.

    Each dimension of the full iteration space is represented by an instance of `FullIterationSpace.Dimension`.
    Dimensions are ordered slowest-to-fastest: The first dimension corresponds to the slowest coordinate
    and translates to the outermost loop, while the last dimension is the fastest coordinate and translates
    to the innermost loop.
    """

    @dataclass
    class Dimension:
        """One dimension of a dense iteration space"""

        start: PsExpression
        stop: PsExpression
        step: PsExpression
        counter: PsSymbol

    @staticmethod
    def create_with_ghost_layers(
        ctx: KernelCreationContext,
        ghost_layers: int | Sequence[int | tuple[int, int]],
        archetype_field: Field,
    ) -> FullIterationSpace:
        """Create an iteration space over an archetype field with ghost layers."""
        archetype_array = ctx.get_buffer(archetype_field)
        dim = archetype_field.spatial_dimensions

        counters = [
            ctx.get_symbol(name, ctx.index_dtype)
            for name in DEFAULTS.spatial_counter_names[:dim]
        ]

        if isinstance(ghost_layers, int):
            ghost_layers_spec = [(ghost_layers, ghost_layers) for _ in range(dim)]
        else:
            if len(ghost_layers) != dim:
                raise ValueError("Wrong number of entries in ghost layer spec")
            ghost_layers_spec = [
                ((gl, gl) if isinstance(gl, int) else gl) for gl in ghost_layers
            ]

        one = PsConstantExpr(PsConstant(1, ctx.index_dtype))

        ghost_layer_exprs = [
            (
                PsConstantExpr(PsConstant(gl_left, ctx.index_dtype)),
                PsConstantExpr(PsConstant(gl_right, ctx.index_dtype)),
            )
            for (gl_left, gl_right) in ghost_layers_spec
        ]

        spatial_shape = archetype_array.shape[:dim]

        from .typification import Typifier

        typify = Typifier(ctx)

        dimensions = [
            FullIterationSpace.Dimension(
                gl_left, typify(PsExpression.make(shape) - gl_right), one, ctr
            )
            for (gl_left, gl_right), shape, ctr in zip(
                ghost_layer_exprs, spatial_shape, counters, strict=True
            )
        ]

        return FullIterationSpace(ctx, dimensions, archetype_field=archetype_field)

    @staticmethod
    def create_from_slice(
        ctx: KernelCreationContext,
        iteration_slice: int | slice | tuple[int | slice, ...],
        archetype_field: Field | None = None,
    ):
        """Create an iteration space from a sequence of slices, optionally over an archetype field.

        Args:
            ctx: The kernel creation context
            iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice`
            archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order.
        """
        if not isinstance(iteration_slice, tuple):
            iteration_slice = (iteration_slice,)

        dim = len(iteration_slice)
        if dim == 0:
            raise ValueError(
                "At least one slice must be specified to create an iteration space"
            )

        archetype_size: tuple[PsSymbol | PsConstant | None, ...]
        if archetype_field is not None:
            archetype_array = ctx.get_buffer(archetype_field)

            if archetype_field.spatial_dimensions != dim:
                raise ValueError(
                    f"Number of dimensions in slice ({len(iteration_slice)}) "
                    f"did not equal iteration space dimensionality ({dim})"
                )

            archetype_size = tuple(archetype_array.shape[:dim])
        else:
            archetype_size = (None,) * dim

        counters = [
            ctx.get_symbol(name, ctx.index_dtype)
            for name in DEFAULTS.spatial_counter_names[:dim]
        ]

        from .ast_factory import AstFactory

        factory = AstFactory(ctx)

        def to_dim(
            slic: int | slice, size: PsSymbol | PsConstant | None, ctr: PsSymbol
        ):
            start, stop, step = factory.parse_slice(slic, size)
            return FullIterationSpace.Dimension(start, stop, step, ctr)

        dimensions = [
            to_dim(slic, size, ctr)
            for slic, size, ctr in zip(
                iteration_slice, archetype_size, counters, strict=True
            )
        ]

        return FullIterationSpace(ctx, dimensions, archetype_field=archetype_field)

    def __init__(
        self,
        ctx: KernelCreationContext,
        dimensions: Sequence[FullIterationSpace.Dimension],
        archetype_field: Field | None = None,
    ):
        super().__init__(tuple(dim.counter for dim in dimensions))

        self._ctx = ctx
        self._dimensions = dimensions
        self._archetype_field = archetype_field

    @property
    def dimensions(self):
        """The dimensions of this iteration space"""
        return self._dimensions

    @property
    def counters(self) -> tuple[PsSymbol, ...]:
        return tuple(dim.counter for dim in self._dimensions)

    @property
    def lower(self) -> tuple[PsExpression, ...]:
        """Lower limits of each dimension"""
        return tuple(dim.start for dim in self._dimensions)

    @property
    def upper(self) -> tuple[PsExpression, ...]:
        """Upper limits of each dimension"""
        return tuple(dim.stop for dim in self._dimensions)

    @property
    def steps(self) -> tuple[PsExpression, ...]:
        """Iteration steps of each dimension"""
        return tuple(dim.step for dim in self._dimensions)

    @property
    def archetype_field(self) -> Field | None:
        """Field whose shape and memory layout act as archetypes for this iteration space's dimensions."""
        return self._archetype_field

    @property
    def loop_order(self) -> tuple[int, ...]:
        """Return the loop order of this iteration space, ordered from slowest to fastest coordinate."""
        if self._archetype_field is not None:
            return self._archetype_field.layout
        else:
            return tuple(range(len(self.dimensions)))

    def dimensions_in_loop_order(self) -> Sequence[FullIterationSpace.Dimension]:
        """Return the dimensions of this iteration space ordered from the slowest to the fastest coordinate.

        If this iteration space has an `archetype field <FullIterationSpace.archetype_field>` set,
        its field layout is used to determine the ideal loop order;
        otherwise, the dimensions are returned as they are.
        """
        return [self._dimensions[i] for i in self.loop_order]

    def actual_iterations(
        self, dimension: int | FullIterationSpace.Dimension | None = None
    ) -> PsExpression:
        """Construct an expression representing the actual number of unique points inside the iteration space.

        Args:
            dimension: If an integer or a `Dimension` object is given, the number of iterations in that
                dimension is computed. If `None`, the total number of iterations inside the entire space
                is computed.
        """
        from .typification import Typifier
        from ..transformations import EliminateConstants

        typify = Typifier(self._ctx)
        fold = EliminateConstants(self._ctx)

        if dimension is None:
            return fold(
                typify(
                    reduce(
                        mul,
                        (
                            self.actual_iterations(d)
                            for d in range(len(self.dimensions))
                        ),
                    )
                )
            )
        else:
            if isinstance(dimension, FullIterationSpace.Dimension):
                dim = dimension
            else:
                dim = self.dimensions[dimension]
            one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype))
            zero = PsConstantExpr(PsConstant(0, self._ctx.index_dtype))
            return fold(
                typify(
                    PsTernary(
                        PsEq(PsRem((dim.stop - dim.start), dim.step), zero),
                        (dim.stop - dim.start) / dim.step,
                        (dim.stop - dim.start) / dim.step + one,
                    )
                )
            )
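
    #   Illustrative note (added for documentation, not in the original source):
    #   for a dimension with start=1, stop=10, step=3 the ternary above folds to
    #   (10 - 1) / 3 == 3 iterations, since the remainder is zero; for stop=11 it
    #   folds to (11 - 1) / 3 + 1 == 4, counting the trailing partial step.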

    def compressed_counter(self) -> PsExpression:
        """Expression counting the actual number of items processed at the iteration defined by the counter tuple.

        Used primarily for indexing buffers."""
        actual_iters = [self.actual_iterations(d) for d in range(self.rank)]
        compressed_counters = [
            (PsExpression.make(dim.counter) - dim.start) / dim.step
            for dim in self.dimensions
        ]
        compressed_idx = compressed_counters[0]
        for ctr, iters in zip(compressed_counters[1:], actual_iters[1:]):
            compressed_idx = compressed_idx * iters + ctr
        return compressed_idx
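
    #   Illustrative note (added for documentation, not in the original source):
    #   for a 2D space with per-dimension iteration counts (n0, n1) and
    #   compressed per-dimension counters (c0, c1), the loop above produces the
    #   row-major linearization c0 * n1 + c1.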


class SparseIterationSpace(IterationSpace):
    """Represents a sparse iteration space defined by an index list."""

    def __init__(
        self,
        spatial_indices: Sequence[PsSymbol],
        index_list: PsBuffer,
        coordinate_members: Sequence[PsStructType.Member],
        sparse_counter: PsSymbol,
    ):
        super().__init__(spatial_indices)
        self._index_list = index_list
        self._coord_members = tuple(coordinate_members)
        self._sparse_counter = sparse_counter

    @property
    def index_list(self) -> PsBuffer:
        return self._index_list

    @property
    def coordinate_members(self) -> tuple[PsStructType.Member, ...]:
        return self._coord_members

    @property
    def sparse_counter(self) -> PsSymbol:
        return self._sparse_counter


def get_archetype_field(
    fields: set[Field],
    check_compatible_shapes: bool = True,
    check_same_layouts: bool = True,
    check_same_dimensions: bool = True,
):
    """Retrieve an archetype field from a collection of fields, which represents their common properties.

    Raises:
        KernelConstraintsError: If any two fields with conflicting properties are encountered.
    """
    shapes = set(f.spatial_shape for f in fields)
    fixed_shapes = set(f.spatial_shape for f in fields if f.has_fixed_shape)
    layouts = set(f.layout for f in fields)
    dimensionalities = set(f.spatial_dimensions for f in fields)

    if check_same_dimensions and len(dimensionalities) != 1:
        raise KernelConstraintsError(
            "All fields must have the same number of spatial dimensions."
        )

    if check_same_layouts and len(layouts) != 1:
        raise KernelConstraintsError("All fields must have the same memory layout.")

    if check_compatible_shapes:
        if len(fixed_shapes) > 0:
            if len(fixed_shapes) != len(shapes):
                raise KernelConstraintsError(
                    "Cannot mix fixed- and variable-shape fields."
                )
            if len(fixed_shapes) > 1:
                raise KernelConstraintsError(
                    "Fixed-shape fields of different sizes encountered."
                )

    archetype_field = sorted(fields, key=lambda f: str(f))[0]
    return archetype_field


def create_sparse_iteration_space(
    ctx: KernelCreationContext,
    assignments: AssignmentCollection,
    index_field: Field | None = None,
) -> IterationSpace:
    #   All domain and custom fields must have the same spatial dimensions
    #   TODO: Must all domain fields have the same shape?
    archetype_field = get_archetype_field(
        ctx.fields.domain_fields | ctx.fields.custom_fields,
        check_compatible_shapes=False,
        check_same_layouts=False,
        check_same_dimensions=True,
    )

    dim = archetype_field.spatial_dimensions
    coord_members = [
        PsStructType.Member(name, ctx.index_dtype)
        for name in DEFAULTS.index_struct_coordinate_names[:dim]
    ]

    #   Determine index field
    if index_field is not None:
        idx_arr = ctx.get_buffer(index_field)
        idx_struct_type: PsStructType = failing_cast(PsStructType, idx_arr.element_type)

        for coord in coord_members:
            if coord not in idx_struct_type.members:
                raise PsInputError(
                    f"Given index field does not provide required coordinate member {coord}"
                )
    else:
        #   TODO: Find index field from the fields list
        raise NotImplementedError(
            "Automatic inference of index field for sparse iteration not supported yet."
        )

    spatial_counters = [
        ctx.get_symbol(name, ctx.index_dtype)
        for name in DEFAULTS.spatial_counter_names[:dim]
    ]

    sparse_counter = ctx.get_symbol(DEFAULTS.sparse_counter_name, ctx.index_dtype)

    return SparseIterationSpace(
        spatial_counters, idx_arr, coord_members, sparse_counter
    )


def create_full_iteration_space(
    ctx: KernelCreationContext,
    assignments: AssignmentCollection,
    ghost_layers: None | int | Sequence[int | tuple[int, int]] = None,
    iteration_slice: None | int | slice | tuple[int | slice, ...] = None,
    infer_ghost_layers: bool = False,
) -> IterationSpace:
    """Create a dense iteration space from a sequence of assignments and iteration slice information.

    This function finds all accesses to fields in the given assignment collection,
    analyzes the set of fields involved,
    and determines the iteration space bounds from these.
    This requires that either all fields are of the same, fixed, shape, or all of them are
    variable-shaped.
    Also, all fields need to have the same memory layout of their spatial dimensions.

    Args:
        ctx: The kernel creation context
        assignments: Collection of assignments the iteration space should be inferred from
        ghost_layers: If set, strip off that many ghost layers from all sides of the iteration cuboid
        iteration_slice: If set, constrain iteration to the given slice.
            For details on the parsing of slices, see `AstFactory.parse_slice`.
        infer_ghost_layers: If `True`, infer the number of ghost layers from the stencil ranges
            used in the kernel.

    Returns:
        IterationSpace: The constructed iteration space.

    Raises:
        KernelConstraintsError: If field shape or memory layout conflicts are detected
        ValueError: If the iteration slice could not be parsed

    .. attention::
        The ``ghost_layers`` and ``iteration_slice`` arguments are mutually exclusive.
        Also, if ``infer_ghost_layers=True``, neither of them may be set.
    """
    assert not ctx.fields.index_fields

    if (ghost_layers is None) and (iteration_slice is None) and not infer_ghost_layers:
        raise ValueError(
            "One argument of `ghost_layers`, `iteration_slice`, and `infer_ghost_layers` must be set."
        )

    if (
        int(ghost_layers is not None)
        + int(iteration_slice is not None)
        + int(infer_ghost_layers)
        > 1
    ):
        raise ValueError(
            "At most one of `ghost_layers`, `iteration_slice`, and `infer_ghost_layers` may be set."
        )

    #   Collect all relative accesses into domain fields
    def access_filter(acc: Field.Access):
        return acc.field.field_type in (
            FieldType.GENERIC,
            FieldType.STAGGERED,
            FieldType.STAGGERED_FLUX,
        )

    domain_field_accesses = assignments.atoms(Field.Access)
    domain_field_accesses = set(filter(access_filter, domain_field_accesses))

    #   The following scenarios exist:
    #   - We have at least one domain field -> find the common field and use it to determine the iteration region
    #   - We have no domain fields, but at least one custom field -> determine common field from custom fields
    #   - We have neither domain nor custom fields -> Error
    if len(domain_field_accesses) > 0:
        archetype_field = get_archetype_field(ctx.fields.domain_fields)
    elif len(ctx.fields.custom_fields) > 0:
        #   TODO: Warn about inferring iteration space from custom fields
        archetype_field = get_archetype_field(ctx.fields.custom_fields)
    else:
        raise PsInputError(
            "Unable to construct iteration space: The kernel contains no accesses to domain or custom fields."
        )

    #   If the user provided a ghost layer specification, use that
    #   Otherwise, if an iteration slice was specified, use that
    #   Otherwise, use the inferred ghost layers
    if infer_ghost_layers:
        if len(domain_field_accesses) > 0:
            inferred_gls = max(
                [fa.required_ghost_layers for fa in domain_field_accesses]
            )
        else:
            inferred_gls = 0

        ctx.metadata["ghost_layers"] = inferred_gls
        return FullIterationSpace.create_with_ghost_layers(
            ctx, inferred_gls, archetype_field
        )
    elif ghost_layers is not None:
        ctx.metadata["ghost_layers"] = ghost_layers
        return FullIterationSpace.create_with_ghost_layers(
            ctx, ghost_layers, archetype_field
        )
    elif iteration_slice is not None:
        return FullIterationSpace.create_from_slice(
            ctx, iteration_slice, archetype_field
        )
    else:
        assert False, "unreachable code"
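

#   Illustrative usage sketch (added for documentation; not part of the
#   original module). The function name is hypothetical; the slice bounds are
#   arbitrary example values.
def _demo_iteration_space_from_slices() -> None:
    from .context import KernelCreationContext  # runtime import; cf. TYPE_CHECKING above

    ctx = KernelCreationContext()

    #   A 2D space iterating over [1, 15) x [0, 32) with unit step
    ispace = FullIterationSpace.create_from_slice(ctx, (slice(1, 15), slice(0, 32)))
    assert ispace.rank == 2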


from __future__ import annotations

from typing import TypeVar, Callable

from .context import KernelCreationContext

from ...types import (
    PsType,
    PsNumericType,
    PsStructType,
    PsIntegerType,
    PsArrayType,
    PsDereferencableType,
    PsPointerType,
    PsBoolType,
    PsScalarType,
    PsVectorType,
    constify,
    deconstify,
)

from ..ast.structural import (
    PsAstNode,
    PsBlock,
    PsLoop,
    PsConditional,
    PsExpression,
    PsAssignment,
    PsDeclaration,
    PsStatement,
    PsEmptyLeafMixIn,
)
from ..ast.expressions import (
    PsBufferAcc,
    PsArrayInitList,
    PsBinOp,
    PsIntOpTrait,
    PsNumericOpTrait,
    PsBoolOpTrait,
    PsCall,
    PsTernary,
    PsCast,
    PsAddressOf,
    PsConstantExpr,
    PsLookup,
    PsSubscript,
    PsMemAcc,
    PsSymbolExpr,
    PsLiteralExpr,
    PsRel,
    PsNeg,
    PsNot,
)
from ..ast.vector import PsVecBroadcast, PsVecMemAcc
from ..functions import PsMathFunction, CFunction
from ..ast.util import determine_memory_object
from ..exceptions import TypificationError

__all__ = ["Typifier"]


NodeT = TypeVar("NodeT", bound=PsAstNode)

ResolutionHook = Callable[[PsType], None]


class TypeContext:
    """Typing context, with support for type inference and checking.

    Instances of this class are used to propagate and check data types across expression subtrees
    of the AST. Each type context has a target type `target_type`, which shall be applied to all
    expressions it covers.
    """

    def __init__(
        self,
        target_type: PsType | None = None,
    ):
        self._deferred_exprs: list[PsExpression] = []
        self._target_type = deconstify(target_type) if target_type is not None else None
        self._hooks: list[ResolutionHook] = []

    @property
    def target_type(self) -> PsType | None:
        """The target type of this type context."""
        return self._target_type

    def add_hook(self, hook: ResolutionHook):
        """Adds a resolution hook to this context.

        The hook will be called with the context's target type as soon as it becomes known,
        which might be immediately.
        """
        if self._target_type is None:
            self._hooks.append(hook)
        else:
            hook(self._target_type)

    def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
        """Applies the given ``dtype`` to this type context, and optionally to the given expression.

        If the context's target_type is already known, it must be compatible with the given dtype.
        If the target type is still unknown, target_type is set to dtype and retroactively applied
        to all deferred expressions.

        If an expression is specified, it will be covered by the type context.
        If the expression already has a data type set, it must be compatible with the target type
        and will be replaced by it.
        """
        dtype = deconstify(dtype)

        if self._target_type is not None and dtype != self._target_type:
            raise TypificationError(
                f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
                f" Expression type: {dtype}\n"
                f" Target type: {self._target_type}"
            )
        else:
            self._target_type = dtype
            self._propagate_target_type()

        if expr is not None:
            self._apply_target_type(expr)

    def infer_dtype(self, expr: PsExpression):
        """Infer the data type for the given expression.

        If the target_type of this context is already known, it will be applied to the given expression.
        Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_dtype` is
        called on this context.

        If the expression already has a data type set, it must be compatible with the target type
        and will be replaced by it.
        """
        if self._target_type is None:
            self._deferred_exprs.append(expr)
        else:
            self._apply_target_type(expr)

    def _propagate_target_type(self):
        assert self._target_type is not None

        for hook in self._hooks:
            hook(self._target_type)
        self._hooks = []

        for expr in self._deferred_exprs:
            self._apply_target_type(expr)
        self._deferred_exprs = []

    def _apply_target_type(self, expr: PsExpression):
        assert self._target_type is not None

        if expr.dtype is not None:
            if not self._compatible(expr.dtype):
                raise TypificationError(
                    f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
                    f" Expression type: {expr.dtype}\n"
                    f" Target type: {self._target_type}"
                )
        else:
            match expr:
                case PsConstantExpr(c):
                    if not isinstance(self._target_type, PsNumericType):
                        raise TypificationError(
                            f"Can't typify constant with non-numeric type {self._target_type}"
                        )
                    if c.dtype is None:
                        expr.constant = c.interpret_as(self._target_type)
                    elif not self._compatible(c.dtype):
                        raise TypificationError(
                            f"Type mismatch at constant {c}: Constant type did not match the context's target type\n"
                            f" Constant type: {c.dtype}\n"
                            f" Target type: {self._target_type}"
                        )

                case PsLiteralExpr(lit):
                    if not self._compatible(lit.dtype):
                        raise TypificationError(
                            f"Type mismatch at literal {lit}: Literal type did not match the context's target type\n"
                            f" Literal type: {lit.dtype}\n"
                            f" Target type: {self._target_type}"
                        )

                case PsSymbolExpr(symb):
                    if symb.dtype is None:
                        #   Symbols are not forced to constness
                        symb.dtype = deconstify(self._target_type)
                    elif not self._compatible(symb.dtype):
                        raise TypificationError(
                            f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n"
                            f" Symbol type: {symb.dtype}\n"
                            f" Target type: {self._target_type}"
                        )

                case PsNumericOpTrait() if not isinstance(
                    self._target_type, PsNumericType
                ) or self._target_type.is_bool():
                    #   FIXME: PsBoolType derives from PsNumericType, but is not numeric
                    raise TypificationError(
                        f"Numerical operation encountered in non-numerical type context:\n"
                        f" Expression: {expr}"
                        f" Type Context: {self._target_type}"
                    )

                case PsIntOpTrait() if not (
                    isinstance(self._target_type, PsNumericType)
                    and self._target_type.is_int()
                ):
                    raise TypificationError(
                        f"Integer operation encountered in non-integer type context:\n"
                        f" Expression: {expr}"
                        f" Type Context: {self._target_type}"
                    )

                case PsBoolOpTrait() if not (
                    isinstance(self._target_type, PsNumericType)
                    and self._target_type.is_bool()
                ):
                    raise TypificationError(
                        f"Boolean operation encountered in non-boolean type context:\n"
                        f" Expression: {expr}"
                        f" Type Context: {self._target_type}"
                    )
            # endif

        expr.dtype = self._target_type

    def _compatible(self, dtype: PsType):
        """Checks whether the given data type is compatible with the context's target type.

        The two must match except for constness.
        """
        assert self._target_type is not None
        return deconstify(dtype) == self._target_type
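

#   Illustrative note (added for documentation, not in the original source):
#   when typifying `x + 1` where x is typed float32, the untyped constant is
#   either deferred by `infer_dtype` or typed immediately, depending on the
#   traversal order; once the symbol fixes the context's target type, the
#   constant is interpreted as float32 by `_apply_target_type`.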


class Typifier:
    """Apply data types to expressions.

    **Contextual Typing**

    The Typifier will traverse the AST and apply a contextual typing scheme to figure out
    the data types of all encountered expressions.
    To this end, it covers each expression tree with a set of disjoint typing contexts.
    All nodes covered by the same typing context must have the same type.

    Starting from an expression's root, a typing context is implicitly expanded through
    the recursive descent into a node's children. In particular, a child is typified within
    the same context as its parent if the node's semantics require parent and child to have
    the same type (e.g. at arithmetic operators, mathematical functions, etc.).
    If a node's child is required to have a different type, a new context is opened.

    For each typing context, its target type is prescribed by the first node encountered during traversal
    whose type is fixed according to its typing rules. All other nodes covered by the context must share
    that type.

    The types of arithmetic operators, mathematical functions, and untyped constants are
    inferred from their context's target type. If one of these is encountered while no target type is set yet
    in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon
    as the target type is fixed.

    **Typing Rules**

    The following general rules apply:

    - The context's ``default_dtype`` is applied to all untyped symbols encountered inside a right-hand side expression
    - If an untyped symbol is encountered on an assignment's left-hand side, its type is first inferred from the
      right-hand side. If that fails, the context's ``default_dtype`` will be applied.
    - It is an error if an untyped symbol occurs in the same type context as a typed symbol or constant
      with a non-default data type.
    - By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's
      left-hand side

    **Typing of symbol expressions**

    Some expressions (`PsSymbolExpr`, `PsBufferAcc`) encapsulate symbols and inherit their data types.
    """

    def __init__(self, ctx: KernelCreationContext):
        self._ctx = ctx

    def __call__(self, node: NodeT) -> NodeT:
        if isinstance(node, PsExpression):
            tc = TypeContext()
            self.visit_expr(node, tc)

            if tc.target_type is None:
                #   no type could be inferred -> take the default
                tc.apply_dtype(self._ctx.default_dtype)
        else:
            self.visit(node)
        return node

    def typify_expression(
        self, expr: PsExpression, target_type: PsType | None = None
    ) -> tuple[PsExpression, PsType]:
        tc = TypeContext(target_type)
        self.visit_expr(expr, tc)

        if tc.target_type is None:
            raise TypificationError(f"Unable to determine type for {expr}")

        return expr, tc.target_type

    def visit(self, node: PsAstNode) -> None:
        """Recursive processing of structural nodes"""
        match node:
            case PsBlock([*statements]):
                for s in statements:
                    self.visit(s)

            case PsStatement(expr):
                tc = TypeContext()
                self.visit_expr(expr, tc)
                if tc.target_type is None:
                    tc.apply_dtype(self._ctx.default_dtype)

            case PsDeclaration(lhs, rhs) if isinstance(rhs, PsArrayInitList):
                #   Special treatment for array declarations
                assert isinstance(lhs, PsSymbolExpr)

                decl_tc = TypeContext()
                items_tc = TypeContext()

                if (lhs_type := lhs.symbol.dtype) is not None:
                    if not isinstance(lhs_type, PsArrayType):
                        raise TypificationError(
                            f"Illegal LHS type in array declaration: {lhs_type}"
                        )

                    if lhs_type.shape != rhs.shape:
                        raise TypificationError(
                            f"Incompatible shapes in declaration of array symbol {lhs.symbol}.\n"
                            f" Symbol shape: {lhs_type.shape}\n"
                            f" Array shape: {rhs.shape}"
                        )

                    items_tc.apply_dtype(lhs_type.base_type)
                    decl_tc.apply_dtype(lhs_type, lhs)
                else:
                    decl_tc.infer_dtype(lhs)

                for item in rhs.items:
                    self.visit_expr(item, items_tc)

                if items_tc.target_type is None:
                    items_tc.apply_dtype(self._ctx.default_dtype)

                if decl_tc.target_type is None:
                    assert items_tc.target_type is not None
                    decl_tc.apply_dtype(
                        PsArrayType(items_tc.target_type, rhs.shape), rhs
                    )
                else:
                    decl_tc.infer_dtype(rhs)

            case PsDeclaration(lhs, rhs) | PsAssignment(lhs, rhs):
                #   Only if the LHS is an untyped symbol, infer its type from the RHS
                infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None

                tc = TypeContext()

                if infer_lhs:
                    tc.infer_dtype(lhs)
                else:
                    self.visit_expr(lhs, tc)
                    assert tc.target_type is not None

                self.visit_expr(rhs, tc)

                if infer_lhs and tc.target_type is None:
                    #   no type has been inferred -> use the default dtype
                    tc.apply_dtype(self._ctx.default_dtype)
                elif not isinstance(node, PsDeclaration):
                    #   check mutability of LHS
                    _, lhs_const = determine_memory_object(lhs)
                    if lhs_const:
                        raise TypificationError(f"Cannot assign to immutable LHS {lhs}")

            case PsConditional(cond, branch_true, branch_false):
                cond_tc = TypeContext(PsBoolType())
                self.visit_expr(cond, cond_tc)

                self.visit(branch_true)

                if branch_false is not None:
                    self.visit(branch_false)

            case PsLoop(ctr, start, stop, step, body):
                if ctr.symbol.dtype is None:
                    ctr.symbol.apply_dtype(self._ctx.index_dtype)

                ctr.dtype = ctr.symbol.get_dtype()

                tc_index = TypeContext(ctr.symbol.dtype)
                self.visit_expr(start, tc_index)
                self.visit_expr(stop, tc_index)
                self.visit_expr(step, tc_index)

                self.visit(body)

            case PsEmptyLeafMixIn():
                pass

            case _:
                raise NotImplementedError(f"Can't typify {node}")

    def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None:
        """Recursive processing of expression nodes.

        This method opens, expands, and closes typing contexts according to the respective expression's
        typing rules. It may add or check restrictions only when opening or closing a type context.

        The actual type inference and checking during context expansion are performed by the methods
        of `TypeContext`. ``visit_expr`` tells the typing context how to handle an expression by calling
        either ``apply_dtype`` or ``infer_dtype``.
        """
        match expr:
            case PsSymbolExpr(symb):
                if symb.dtype is None:
                    symb.dtype = self._ctx.default_dtype
                tc.apply_dtype(symb.dtype, expr)

            case PsConstantExpr(c):
                if c.dtype is not None:
                    tc.apply_dtype(c.dtype, expr)
                else:
                    tc.infer_dtype(expr)

            case PsLiteralExpr(lit):
                tc.apply_dtype(lit.dtype, expr)

            case PsBufferAcc(_, indices):
                tc.apply_dtype(expr.buffer.element_type, expr)
                for idx in indices:
                    self._handle_idx(idx)

            case PsMemAcc(ptr, offset) | PsVecMemAcc(ptr, offset):
                ptr_tc = TypeContext()
                self.visit_expr(ptr, ptr_tc)

                if not isinstance(ptr_tc.target_type, PsPointerType):
                    raise TypificationError(
                        f"Type of pointer argument to memory access was not a pointer type: {ptr_tc.target_type}"
                    )

                tc.apply_dtype(ptr_tc.target_type.base_type, expr)
self._handle_idx(offset)
if isinstance(expr, PsVecMemAcc) and expr.stride is not None:
self._handle_idx(expr.stride)
case PsSubscript(arr, indices):
if isinstance(arr, PsArrayInitList):
shape = arr.shape
# extend outer context over the init-list entries
for item in arr.items:
self.visit_expr(item, tc)
# learn the array type from the items
def arr_hook(element_type: PsType):
arr.dtype = PsArrayType(element_type, arr.shape)
tc.add_hook(arr_hook)
else:
# type of array has to be known
arr_tc = TypeContext()
self.visit_expr(arr, arr_tc)
if not isinstance(arr_tc.target_type, PsArrayType):
raise TypificationError(
f"Type of array argument to subscript was not an array type: {arr_tc.target_type}"
)
tc.apply_dtype(arr_tc.target_type.base_type, expr)
shape = arr_tc.target_type.shape
if len(indices) != len(shape):
raise TypificationError(
f"Invalid number of indices to {len(shape)}-dimensional array: {len(indices)}"
)
for idx in indices:
self._handle_idx(idx)
case PsAddressOf(arg):
if not isinstance(
arg, (PsSymbolExpr, PsSubscript, PsMemAcc, PsBufferAcc, PsLookup)
):
raise TypificationError(
f"Illegal expression below AddressOf operator: {arg}"
)
arg_tc = TypeContext()
self.visit_expr(arg, arg_tc)
if arg_tc.target_type is None:
raise TypificationError(
f"Unable to determine type of argument to AddressOf: {arg}"
)
# Inherit pointed-to type from referenced object, not from the subexpression
match arg:
case PsSymbolExpr(s):
pointed_to_type = s.get_dtype()
case PsSubscript(ptr, _) | PsMemAcc(ptr, _) | PsBufferAcc(ptr, _):
arr_type = ptr.get_dtype()
assert isinstance(arr_type, PsDereferencableType)
pointed_to_type = arr_type.base_type
case PsLookup(aggr, member_name):
struct_type = aggr.get_dtype()
assert isinstance(struct_type, PsStructType)
if struct_type.const:
pointed_to_type = constify(
struct_type.get_member(member_name).dtype
)
else:
pointed_to_type = deconstify(
struct_type.get_member(member_name).dtype
)
case _:
assert False, "unreachable code"
ptr_type = PsPointerType(pointed_to_type, const=True)
tc.apply_dtype(ptr_type, expr)
case PsLookup(aggr, member_name):
# Members of a struct type inherit the struct type's `const` qualifier
aggr_tc = TypeContext()
self.visit_expr(aggr, aggr_tc)
aggr_type = aggr_tc.target_type
if not isinstance(aggr_type, PsStructType):
raise TypificationError(
"Aggregate type of lookup is not a struct type."
)
member = aggr_type.find_member(member_name)
if member is None:
raise TypificationError(
f"Aggregate of type {aggr_type} does not have a member {member_name}."
)
member_type = member.dtype
if aggr_type.const:
member_type = constify(member_type)
tc.apply_dtype(member_type, expr)
case PsTernary(cond, then, els):
cond_tc = TypeContext(target_type=PsBoolType())
self.visit_expr(cond, cond_tc)
self.visit_expr(then, tc)
self.visit_expr(els, tc)
tc.infer_dtype(expr)
case PsRel(op1, op2):
args_tc = TypeContext()
self.visit_expr(op1, args_tc)
self.visit_expr(op2, args_tc)
if args_tc.target_type is None:
raise TypificationError(
f"Unable to determine type of arguments to relation: {expr}"
)
if not isinstance(args_tc.target_type, PsNumericType):
raise TypificationError(
f"Invalid type in arguments to relation\n"
f" Expression: {expr}\n"
f" Arguments Type: {args_tc.target_type}"
)
if isinstance(args_tc.target_type, PsVectorType):
tc.apply_dtype(
PsVectorType(PsBoolType(), args_tc.target_type.vector_entries),
expr,
)
else:
tc.apply_dtype(PsBoolType(), expr)
case PsBinOp(op1, op2):
self.visit_expr(op1, tc)
self.visit_expr(op2, tc)
tc.infer_dtype(expr)
case PsNeg(op) | PsNot(op):
self.visit_expr(op, tc)
tc.infer_dtype(expr)
case PsCall(function, args):
match function:
case PsMathFunction():
for arg in args:
self.visit_expr(arg, tc)
tc.infer_dtype(expr)
case CFunction(_, arg_types, ret_type):
tc.apply_dtype(ret_type, expr)
for arg, arg_type in zip(args, arg_types, strict=True):
arg_tc = TypeContext(arg_type)
self.visit_expr(arg, arg_tc)
case _:
raise TypificationError(
f"Don't know how to typify calls to {function}"
)
case PsArrayInitList(_):
raise TypificationError(
"Unable to typify array initializer in isolation.\n"
f" Array: {expr}"
)
case PsCast(dtype, arg):
arg_tc = TypeContext()
self.visit_expr(arg, arg_tc)
if arg_tc.target_type is None:
raise TypificationError(
f"Unable to determine type of argument to Cast: {arg}"
)
tc.apply_dtype(dtype, expr)
case PsVecBroadcast(lanes, arg):
op_tc = TypeContext()
self.visit_expr(arg, op_tc)
if op_tc.target_type is None:
raise TypificationError(
f"Unable to determine type of argument to vector broadcast: {arg}"
)
if not isinstance(op_tc.target_type, PsScalarType):
raise TypificationError(
f"Illegal type in argument to vector broadcast: {op_tc.target_type}"
)
tc.apply_dtype(PsVectorType(op_tc.target_type, lanes), expr)
case _:
raise NotImplementedError(f"Can't typify {expr}")
def _handle_idx(self, idx: PsExpression):
index_tc = TypeContext()
self.visit_expr(idx, index_tc)
if index_tc.target_type is None:
index_tc.apply_dtype(self._ctx.index_dtype, idx)
elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError(
f"Invalid data type in index expression.\n"
f" Expression: {idx}\n"
f" Type: {index_tc.target_type}"
)
from __future__ import annotations
from ..types import PsType, constify
class PsLiteral:
"""Representation of literal code.
Instances of this class represent code literals inside the AST.
These literals are not to be confused with C literals; the name 'Literal' refers to the fact that
the code generator takes them "literally", printing them as they are.
Each literal has to be annotated with a type, and is considered constant within the scope of a kernel.
Instances of `PsLiteral` are immutable.
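
    **Example**

    A short sketch of creating a code literal, mirroring how the CUDA platform
    represents the thread index ``blockIdx.x``::

        from pystencils.backend.literals import PsLiteral
        from pystencils.types.quick import SInt

        block_idx = PsLiteral("blockIdx.x", SInt(32))
        assert block_idx.dtype.const  # literal types are always constified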
"""
__match_args__ = ("text", "dtype")
def __init__(self, text: str, dtype: PsType) -> None:
self._text = text
self._dtype = constify(dtype)
@property
def text(self) -> str:
return self._text
@property
def dtype(self) -> PsType:
return self._dtype
def __str__(self) -> str:
return f"{self._text}: {self._dtype}"
def __repr__(self) -> str:
return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsLiteral):
return False
return self._text == other._text and self._dtype == other._dtype
def __hash__(self) -> int:
return hash((PsLiteral, self._text, self._dtype))
from __future__ import annotations
from typing import Sequence
from itertools import chain
from dataclasses import dataclass
from ..types import PsType, PsTypeError, deconstify, PsIntegerType, PsPointerType
from .exceptions import PsInternalCompilerError
from .constants import PsConstant
from ..codegen.properties import PsSymbolProperty, UniqueSymbolProperty
class PsSymbol:
"""A mutable symbol with name and data type.
Do not create objects of this class directly unless you know what you are doing;
instead obtain them from a `KernelCreationContext` through `KernelCreationContext.get_symbol`.
This way, the context can keep track of all symbols used in the translation run,
and uniqueness of symbols is ensured.
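
    **Example**

    A sketch of the intended way to obtain and type a symbol::

        from pystencils.backend.kernelcreation import KernelCreationContext
        from pystencils.types.quick import Fp

        ctx = KernelCreationContext()
        x = ctx.get_symbol("x")  # created untyped
        x.apply_dtype(Fp(64))    # typed as double; a conflicting type would raise `PsTypeError`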
"""
__match_args__ = ("name", "dtype")
def __init__(self, name: str, dtype: PsType | None = None):
self._name = name
self._dtype = dtype
self._properties: set[PsSymbolProperty] = set()
@property
def name(self) -> str:
return self._name
@property
def dtype(self) -> PsType | None:
return self._dtype
@dtype.setter
def dtype(self, value: PsType):
self._dtype = value
def apply_dtype(self, dtype: PsType):
"""Apply the given data type to this symbol,
raising a TypeError if it conflicts with a previously set data type."""
if self._dtype is not None and self._dtype != dtype:
raise PsTypeError(
f"Incompatible symbol data types: {self._dtype} and {dtype}"
)
self._dtype = dtype
def get_dtype(self) -> PsType:
if self._dtype is None:
raise PsInternalCompilerError(
f"Symbol {self.name} had no type assigned yet"
)
return self._dtype
@property
def properties(self) -> frozenset[PsSymbolProperty]:
"""Set of properties attached to this symbol"""
return frozenset(self._properties)
def get_properties(
self, prop_type: type[PsSymbolProperty]
) -> set[PsSymbolProperty]:
"""Retrieve all properties of the given type attached to this symbol"""
return set(filter(lambda p: isinstance(p, prop_type), self._properties))
def add_property(self, property: PsSymbolProperty):
"""Attach a property to this symbol"""
if isinstance(property, UniqueSymbolProperty) and not self.get_properties(
type(property)
) <= {property}:
raise ValueError(
f"Cannot add second instance of unique property {type(property)} to symbol {self._name}."
)
self._properties.add(property)
def remove_property(self, property: PsSymbolProperty):
"""Remove a property from this symbol. Does nothing if the property is not attached."""
self._properties.discard(property)
def __str__(self) -> str:
dtype_str = "<untyped>" if self._dtype is None else str(self._dtype)
return f"{self._name}: {dtype_str}"
def __repr__(self) -> str:
return f"PsSymbol({repr(self._name)}, {repr(self._dtype)})"
@dataclass(frozen=True)
class BufferBasePtr(UniqueSymbolProperty):
"""Symbol acts as a base pointer to a buffer."""
buffer: PsBuffer
class PsBuffer:
"""N-dimensional contiguous linearized buffer in heap memory.
`PsBuffer` models the memory buffers underlying the `Field` class
to the backend. Each buffer represents a contiguous block of memory
that is non-aliased and disjoint from all other buffers.
Buffer shape and stride information are given either as constants or as symbols.
    All shape and stride entries must have the same integer data type, which is selected as the buffer's
    `index_type <PsBuffer.index_type>`.
Each buffer has at least one base pointer, which can be retrieved via the `PsBuffer.base_pointer`
property.
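
    **Example**

    A self-contained sketch of a two-dimensional, row-major buffer with a fixed inner
    size of three; in an actual translation run, the symbols would instead be obtained
    from the `KernelCreationContext`::

        from pystencils.backend.memory import PsSymbol, PsBuffer
        from pystencils.backend.constants import PsConstant
        from pystencils.types import PsPointerType
        from pystencils.types.quick import Fp, SInt

        idx = SInt(64)
        base_ptr = PsSymbol("f_data", PsPointerType(Fp(64)))
        shape = [PsSymbol("f_size0", idx), PsConstant(3, idx)]
        strides = [PsConstant(3, idx), PsConstant(1, idx)]

        buf = PsBuffer("f", Fp(64), base_ptr, shape, strides)
        assert buf.dim == 2 and buf.index_type == idx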
"""
def __init__(
self,
name: str,
element_type: PsType,
base_ptr: PsSymbol,
shape: Sequence[PsSymbol | PsConstant],
strides: Sequence[PsSymbol | PsConstant],
):
bptr_type = base_ptr.get_dtype()
if not isinstance(bptr_type, PsPointerType):
raise ValueError(
f"Type of buffer base pointer {base_ptr} was not a pointer type: {bptr_type}"
)
if bptr_type.base_type != element_type:
raise ValueError(
f"Base type of primary buffer base pointer {base_ptr} "
f"did not equal buffer element type {element_type}."
)
if len(shape) != len(strides):
raise ValueError("Buffer shape and stride tuples must have the same length")
idx_types: set[PsType] = set(
deconstify(s.get_dtype()) for s in chain(shape, strides)
)
if len(idx_types) > 1:
raise ValueError(
f"Conflicting data types in indexing symbols to buffer {name}: {idx_types}"
)
idx_dtype = idx_types.pop()
if not isinstance(idx_dtype, PsIntegerType):
raise ValueError(
f"Invalid index data type for buffer {name}: {idx_dtype}. Must be an integer type."
)
self._name = name
self._element_type = element_type
self._index_dtype = idx_dtype
self._shape = list(shape)
self._strides = list(strides)
base_ptr.add_property(BufferBasePtr(self))
self._base_ptr = base_ptr
@property
def name(self):
"""The buffer's name"""
return self._name
@property
def base_pointer(self) -> PsSymbol:
"""Primary base pointer"""
return self._base_ptr
@property
def shape(self) -> list[PsSymbol | PsConstant]:
"""Buffer shape symbols and/or constants"""
return self._shape
@property
def strides(self) -> list[PsSymbol | PsConstant]:
"""Buffer stride symbols and/or constants"""
return self._strides
@property
def dim(self) -> int:
"""Dimensionality of this buffer"""
return len(self._shape)
@property
def index_type(self) -> PsIntegerType:
"""Index data type of this buffer; i.e. data type of its shape and stride symbols"""
return self._index_dtype
@property
def element_type(self) -> PsType:
"""Element type of this buffer"""
return self._element_type
def __repr__(self) -> str:
return f"PsBuffer({self._name}: {self.element_type}[{len(self.shape)}D])"
from .platform import Platform
from .generic_cpu import GenericCpu, GenericVectorCpu
from .generic_gpu import GenericGpu
from .cuda import CudaPlatform
from .x86 import X86VectorCpu, X86VectorArch
from .sycl import SyclPlatform
__all__ = [
"Platform",
"GenericCpu",
"GenericVectorCpu",
"X86VectorCpu",
"X86VectorArch",
"GenericGpu",
"CudaPlatform",
"SyclPlatform",
]
from __future__ import annotations
from warnings import warn
from typing import TYPE_CHECKING
from ...types import constify
from ..exceptions import MaterializationError
from .generic_gpu import GenericGpu
from ..kernelcreation import (
Typifier,
IterationSpace,
FullIterationSpace,
SparseIterationSpace,
AstFactory,
)
from ..kernelcreation.context import KernelCreationContext
from ..ast.structural import PsBlock, PsConditional, PsDeclaration
from ..ast.expressions import (
PsExpression,
PsLiteralExpr,
PsCast,
PsCall,
PsLookup,
PsBufferAcc,
)
from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType, PsIeeeFloatType
from ..literals import PsLiteral
from ..functions import PsMathFunction, MathFunctions, CFunction
if TYPE_CHECKING:
from ...codegen import GpuThreadsRange
int32 = PsSignedIntegerType(width=32, const=False)
BLOCK_IDX = [
PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z")
]
THREAD_IDX = [
PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z")
]
BLOCK_DIM = [
PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z")
]
GRID_DIM = [
PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z")
]
class CudaPlatform(GenericGpu):
"""Platform for CUDA-based GPUs."""
    def __init__(
        self,
        ctx: KernelCreationContext,
        omit_range_check: bool = False,
        manual_launch_grid: bool = False,
    ) -> None:
super().__init__(ctx)
self._omit_range_check = omit_range_check
self._manual_launch_grid = manual_launch_grid
self._typify = Typifier(ctx)
@property
def required_headers(self) -> set[str]:
return {'"gpu_defines.h"'}
def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace
) -> tuple[PsBlock, GpuThreadsRange | None]:
if isinstance(ispace, FullIterationSpace):
return self._prepend_dense_translation(body, ispace)
elif isinstance(ispace, SparseIterationSpace):
return self._prepend_sparse_translation(body, ispace)
else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def select_function(self, call: PsCall) -> PsExpression:
assert isinstance(call.function, PsMathFunction)
func = call.function.func
dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args
if isinstance(dtype, PsIeeeFloatType):
match func:
case (
MathFunctions.Exp
| MathFunctions.Log
| MathFunctions.Sin
| MathFunctions.Cos
| MathFunctions.Sqrt
| MathFunctions.Ceil
| MathFunctions.Floor
) if dtype.width in (16, 32, 64):
prefix = "h" if dtype.width == 16 else ""
suffix = "f" if dtype.width == 32 else ""
name = f"{prefix}{func.function_name}{suffix}"
cfunc = CFunction(name, arg_types, dtype)
case (
MathFunctions.Pow
| MathFunctions.Tan
| MathFunctions.Sinh
| MathFunctions.Cosh
| MathFunctions.ASin
| MathFunctions.ACos
| MathFunctions.ATan
| MathFunctions.ATan2
) if dtype.width in (32, 64):
# These are unavailable for fp16
suffix = "f" if dtype.width == 32 else ""
name = f"{func.function_name}{suffix}"
cfunc = CFunction(name, arg_types, dtype)
case (
MathFunctions.Min | MathFunctions.Max | MathFunctions.Abs
) if dtype.width in (32, 64):
suffix = "f" if dtype.width == 32 else ""
name = f"f{func.function_name}{suffix}"
cfunc = CFunction(name, arg_types, dtype)
case MathFunctions.Abs if dtype.width == 16:
cfunc = CFunction(" __habs", arg_types, dtype)
case _:
raise MaterializationError(
f"Cannot materialize call to function {func}"
)
call.function = cfunc
return call
raise MaterializationError(
f"No implementation available for function {func} on data type {dtype}"
)
# Internals
def _prepend_dense_translation(
self, body: PsBlock, ispace: FullIterationSpace
) -> tuple[PsBlock, GpuThreadsRange | None]:
dimensions = ispace.dimensions_in_loop_order()
if not self._manual_launch_grid:
try:
threads_range = self.threads_from_ispace(ispace)
except MaterializationError as e:
warn(
str(e.args[0])
+ "\nIf this is intended, set `manual_launch_grid=True` in the code generator configuration.",
UserWarning,
)
threads_range = None
else:
threads_range = None
indexing_decls = []
conds = []
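        # Dimensions are traversed in reverse loop order, i.e. fastest coordinate first,
        # so that thread index x is mapped to the fastest-varying loop counter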
for i, dim in enumerate(dimensions[::-1]):
dim.counter.dtype = constify(dim.counter.get_dtype())
ctr = PsExpression.make(dim.counter)
indexing_decls.append(
self._typify(
PsDeclaration(
ctr,
dim.start
+ dim.step
* PsCast(ctr.get_dtype(), self._linear_thread_idx(i)),
)
)
)
if not self._omit_range_check:
conds.append(PsLt(ctr, dim.stop))
indexing_decls = indexing_decls[::-1]
if conds:
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
else:
body.statements = indexing_decls + body.statements
ast = body
return ast, threads_range
def _prepend_sparse_translation(
self, body: PsBlock, ispace: SparseIterationSpace
) -> tuple[PsBlock, GpuThreadsRange]:
factory = AstFactory(self._ctx)
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
sparse_ctr = PsExpression.make(ispace.sparse_counter)
thread_idx = self._linear_thread_idx(0)
sparse_idx_decl = self._typify(
PsDeclaration(sparse_ctr, PsCast(sparse_ctr.get_dtype(), thread_idx))
)
mappings = [
PsDeclaration(
PsExpression.make(ctr),
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
(sparse_ctr, factory.parse_index(0)),
),
coord.name,
),
)
for ctr, coord in zip(ispace.spatial_indices, ispace.coordinate_members)
]
body.statements = mappings + body.statements
if not self._omit_range_check:
stop = PsExpression.make(ispace.index_list.shape[0])
condition = PsLt(sparse_ctr, stop)
ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
else:
body.statements = [sparse_idx_decl] + body.statements
ast = body
return ast, self.threads_from_ispace(ispace)
def _linear_thread_idx(self, coord: int):
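        """Linearized GPU thread index along the given coordinate: ``blockIdx * blockDim + threadIdx``."""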
block_size = BLOCK_DIM[coord]
block_idx = BLOCK_IDX[coord]
thread_idx = THREAD_IDX[coord]
return block_idx * block_size + thread_idx
from abc import ABC, abstractmethod
from typing import Sequence
from pystencils.backend.ast.expressions import PsCall
from ..functions import CFunction, PsMathFunction, MathFunctions
from ...types import PsIntegerType, PsIeeeFloatType
from .platform import Platform
from ..exceptions import MaterializationError
from ..kernelcreation import AstFactory
from ..kernelcreation.iteration_space import (
IterationSpace,
FullIterationSpace,
SparseIterationSpace,
)
from ..constants import PsConstant
from ..ast.structural import PsDeclaration, PsLoop, PsBlock
from ..ast.expressions import (
PsSymbolExpr,
PsExpression,
PsBufferAcc,
PsLookup,
PsGe,
PsLe,
PsTernary,
)
from ..ast.vector import PsVecMemAcc
from ...types import PsVectorType, PsCustomType
class GenericCpu(Platform):
"""Generic CPU platform.
    The `GenericCpu` platform models the following execution environment:
- Generic multicore CPU architecture
- Iteration space represented by a loop nest, kernels are executed as a whole
- C standard library math functions available (``#include <math.h>`` or ``#include <cmath>``)
"""
@property
def required_headers(self) -> set[str]:
return {"<cmath>"}
def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace
) -> PsBlock:
if isinstance(ispace, FullIterationSpace):
return self._create_domain_loops(body, ispace)
elif isinstance(ispace, SparseIterationSpace):
return self._create_sparse_loop(body, ispace)
else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def select_function(self, call: PsCall) -> PsExpression:
assert isinstance(call.function, PsMathFunction)
func = call.function.func
dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args
if isinstance(dtype, PsIeeeFloatType) and dtype.width in (32, 64):
cfunc: CFunction
match func:
case (
MathFunctions.Exp
| MathFunctions.Log
| MathFunctions.Sin
| MathFunctions.Cos
| MathFunctions.Tan
| MathFunctions.Sinh
| MathFunctions.Cosh
| MathFunctions.ASin
| MathFunctions.ACos
| MathFunctions.ATan
| MathFunctions.ATan2
| MathFunctions.Pow
| MathFunctions.Sqrt
| MathFunctions.Floor
| MathFunctions.Ceil
):
cfunc = CFunction(func.function_name, arg_types, dtype)
case MathFunctions.Abs | MathFunctions.Min | MathFunctions.Max:
cfunc = CFunction("f" + func.function_name, arg_types, dtype)
call.function = cfunc
return call
if isinstance(dtype, PsIntegerType):
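            # The float branch above maps abs/min/max to fabs/fmin/fmax;
            # for integer operands, emulate them using ternary expressions instead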
match func:
case MathFunctions.Abs:
zero = PsExpression.make(PsConstant(0, dtype))
arg = call.args[0]
return PsTernary(PsGe(arg, zero), arg, -arg)
case MathFunctions.Min:
arg1, arg2 = call.args
return PsTernary(PsLe(arg1, arg2), arg1, arg2)
case MathFunctions.Max:
arg1, arg2 = call.args
return PsTernary(PsGe(arg1, arg2), arg1, arg2)
raise MaterializationError(
f"No implementation available for function {func} on data type {dtype}"
)
# Internals
def _create_domain_loops(
self, body: PsBlock, ispace: FullIterationSpace
) -> PsBlock:
factory = AstFactory(self._ctx)
# Determine loop order by permuting dimensions
archetype_field = ispace.archetype_field
if archetype_field is not None:
loop_order = archetype_field.layout
else:
loop_order = None
loops = factory.loops_from_ispace(ispace, body, loop_order)
return PsBlock([loops])
def _create_sparse_loop(self, body: PsBlock, ispace: SparseIterationSpace):
factory = AstFactory(self._ctx)
mappings = [
PsDeclaration(
PsSymbolExpr(ctr),
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
(
PsExpression.make(ispace.sparse_counter),
factory.parse_index(0),
),
),
coord.name,
),
)
for ctr, coord in zip(ispace.spatial_indices, ispace.coordinate_members)
]
body = PsBlock(mappings + body.statements)
loop = PsLoop(
PsSymbolExpr(ispace.sparse_counter),
PsExpression.make(PsConstant(0, self._ctx.index_dtype)),
PsExpression.make(ispace.index_list.shape[0]),
PsExpression.make(PsConstant(1, self._ctx.index_dtype)),
body,
)
return PsBlock([loop])
class GenericVectorCpu(GenericCpu, ABC):
"""Base class for CPU platforms with vectorization support through intrinsics."""
@abstractmethod
def type_intrinsic(self, vector_type: PsVectorType) -> PsCustomType:
"""Return the intrinsic vector type for the given generic vector type,
        or raise a `MaterializationError` if the type is not supported."""
@abstractmethod
def constant_intrinsic(self, c: PsConstant) -> PsExpression:
"""Return an expression that initializes a constant vector,
or raise a `MaterializationError` if not supported."""
@abstractmethod
def op_intrinsic(
self, expr: PsExpression, operands: Sequence[PsExpression]
) -> PsExpression:
"""Return an expression intrinsically invoking the given operation
or raise a `MaterializationError` if not supported."""
@abstractmethod
def math_func_intrinsic(
self, expr: PsCall, operands: Sequence[PsExpression]
) -> PsExpression:
"""Return an expression intrinsically invoking the given mathematical
function or raise a `MaterializationError` if not supported."""
@abstractmethod
def vector_load(self, acc: PsVecMemAcc) -> PsExpression:
"""Return an expression intrinsically performing a vector load,
or raise a `MaterializationError` if not supported."""
@abstractmethod
def vector_store(self, acc: PsVecMemAcc, arg: PsExpression) -> PsExpression:
"""Return an expression intrinsically performing a vector store,
or raise a `MaterializationError` if not supported."""
from __future__ import annotations
from typing import TYPE_CHECKING
from abc import abstractmethod
from ..ast.expressions import PsExpression
from ..ast.structural import PsBlock
from ..kernelcreation.iteration_space import (
IterationSpace,
FullIterationSpace,
SparseIterationSpace,
)
from .platform import Platform
from ..exceptions import MaterializationError
if TYPE_CHECKING:
from ...codegen.kernel import GpuThreadsRange
class GenericGpu(Platform):
@abstractmethod
def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace
) -> tuple[PsBlock, GpuThreadsRange | None]:
pass
@classmethod
def threads_from_ispace(cls, ispace: IterationSpace) -> GpuThreadsRange:
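        """Compute the range of GPU threads required to cover the given iteration space."""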
from ...codegen.kernel import GpuThreadsRange
if isinstance(ispace, FullIterationSpace):
return cls._threads_from_full_ispace(ispace)
elif isinstance(ispace, SparseIterationSpace):
work_items = (PsExpression.make(ispace.index_list.shape[0]),)
return GpuThreadsRange(work_items)
else:
assert False
@classmethod
def _threads_from_full_ispace(cls, ispace: FullIterationSpace) -> GpuThreadsRange:
from ...codegen.kernel import GpuThreadsRange
dimensions = ispace.dimensions_in_loop_order()[::-1]
if len(dimensions) > 3:
raise NotImplementedError(
f"Cannot create a GPU threads range for an {len(dimensions)}-dimensional iteration space"
)
from ..ast.analysis import collect_undefined_symbols as collect
for dim in dimensions:
symbs = collect(dim.start) | collect(dim.stop) | collect(dim.step)
for ctr in ispace.counters:
if ctr in symbs:
raise MaterializationError(
"Unable to construct GPU threads range for iteration space: "
f"Limits of dimension counter {dim.counter.name} "
f"depend on another dimension's counter {ctr.name}"
)
work_items = [ispace.actual_iterations(dim) for dim in dimensions]
return GpuThreadsRange(work_items)
from abc import ABC, abstractmethod
from typing import Any
from ..ast.structural import PsBlock
from ..ast.expressions import PsCall, PsExpression
from ..kernelcreation.context import KernelCreationContext
from ..kernelcreation.iteration_space import IterationSpace
class Platform(ABC):
"""Abstract base class for all supported platforms.
The platform performs all target-dependent tasks during code generation:
- Translation of the iteration space to an index source (loop nest, GPU indexing, ...)
- Platform-specific optimizations (e.g. vectorization, OpenMP)
"""
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
@property
@abstractmethod
def required_headers(self) -> set[str]:
pass
@abstractmethod
def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace
) -> PsBlock | tuple[PsBlock, Any]:
pass
@abstractmethod
def select_function(
self, call: PsCall
) -> PsExpression:
"""Select an implementation for the given function on the given data type.
If no viable implementation exists, raise a `MaterializationError`.
"""
pass
from __future__ import annotations
from typing import TYPE_CHECKING
from ..functions import CFunction, PsMathFunction, MathFunctions
from ..kernelcreation.iteration_space import (
IterationSpace,
FullIterationSpace,
SparseIterationSpace,
)
from ..ast.structural import PsDeclaration, PsBlock, PsConditional
from ..ast.expressions import (
PsExpression,
PsSymbolExpr,
PsSubscript,
PsLt,
PsAnd,
PsCall,
PsGe,
PsLe,
PsTernary,
PsLookup,
PsBufferAcc,
)
from ..extensions.cpp import CppMethodCall
from ..kernelcreation import KernelCreationContext, AstFactory
from ..constants import PsConstant
from .generic_gpu import GenericGpu
from ..exceptions import MaterializationError
from ...types import PsCustomType, PsIeeeFloatType, constify, PsIntegerType
if TYPE_CHECKING:
from ...codegen import GpuThreadsRange
class SyclPlatform(GenericGpu):
def __init__(
self,
ctx: KernelCreationContext,
omit_range_check: bool = False,
        automatic_block_size: bool = False,
    ):
super().__init__(ctx)
self._omit_range_check = omit_range_check
self._automatic_block_size = automatic_block_size
@property
def required_headers(self) -> set[str]:
return {"<sycl/sycl.hpp>"}
def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace
) -> tuple[PsBlock, GpuThreadsRange]:
if isinstance(ispace, FullIterationSpace):
return self._prepend_dense_translation(body, ispace)
elif isinstance(ispace, SparseIterationSpace):
return self._prepend_sparse_translation(body, ispace)
else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def select_function(self, call: PsCall) -> PsExpression:
assert isinstance(call.function, PsMathFunction)
func = call.function.func
dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args
if isinstance(dtype, PsIeeeFloatType) and dtype.width in (16, 32, 64):
match func:
case (
MathFunctions.Exp
| MathFunctions.Log
| MathFunctions.Sin
| MathFunctions.Cos
| MathFunctions.Tan
| MathFunctions.Sinh
| MathFunctions.Cosh
| MathFunctions.ASin
| MathFunctions.ACos
| MathFunctions.ATan
| MathFunctions.ATan2
| MathFunctions.Pow
| MathFunctions.Sqrt
| MathFunctions.Floor
| MathFunctions.Ceil
):
cfunc = CFunction(f"sycl::{func.function_name}", arg_types, dtype)
case MathFunctions.Abs | MathFunctions.Min | MathFunctions.Max:
cfunc = CFunction(f"sycl::f{func.function_name}", arg_types, dtype)
call.function = cfunc
return call
if isinstance(dtype, PsIntegerType):
match func:
case MathFunctions.Abs:
zero = PsExpression.make(PsConstant(0, dtype))
arg = call.args[0]
return PsTernary(PsGe(arg, zero), arg, -arg)
case MathFunctions.Min:
arg1, arg2 = call.args
return PsTernary(PsLe(arg1, arg2), arg1, arg2)
case MathFunctions.Max:
arg1, arg2 = call.args
return PsTernary(PsGe(arg1, arg2), arg1, arg2)
raise MaterializationError(
f"No implementation available for function {func} on data type {dtype}"
)
def _prepend_dense_translation(
self, body: PsBlock, ispace: FullIterationSpace
) -> tuple[PsBlock, GpuThreadsRange]:
rank = ispace.rank
id_type = self._id_type(rank)
id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type))
id_decl = self._id_declaration(rank, id_symbol)
dimensions = ispace.dimensions_in_loop_order()
launch_config = self.threads_from_ispace(ispace)
indexing_decls = [id_decl]
conds = []
        # Unlike CUDA, SYCL linearizes its IDs in C order:
        # the leftmost entry of an ID varies slowest, and the rightmost entry varies fastest.
        # See https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#sec:multi-dim-linearization
for i, dim in enumerate(dimensions):
# Slowest to fastest
coord = PsExpression.make(PsConstant(i, self._ctx.index_dtype))
work_item_idx = PsSubscript(id_symbol, (coord,))
dim.counter.dtype = constify(dim.counter.get_dtype())
work_item_idx.dtype = dim.counter.get_dtype()
ctr = PsExpression.make(dim.counter)
indexing_decls.append(
PsDeclaration(ctr, dim.start + work_item_idx * dim.step)
)
if not self._omit_range_check:
conds.append(PsLt(ctr, dim.stop))
if conds:
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
else:
body.statements = indexing_decls + body.statements
ast = body
return ast, launch_config
def _prepend_sparse_translation(
self, body: PsBlock, ispace: SparseIterationSpace
) -> tuple[PsBlock, GpuThreadsRange]:
factory = AstFactory(self._ctx)
id_type = PsCustomType("sycl::id< 1 >", const=True)
id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type))
zero = PsExpression.make(PsConstant(0, self._ctx.index_dtype))
subscript = PsSubscript(id_symbol, (zero,))
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
subscript.dtype = ispace.sparse_counter.get_dtype()
sparse_ctr = PsExpression.make(ispace.sparse_counter)
sparse_idx_decl = PsDeclaration(sparse_ctr, subscript)
mappings = [
PsDeclaration(
PsExpression.make(ctr),
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
(sparse_ctr, factory.parse_index(0)),
),
coord.name,
),
)
for ctr, coord in zip(ispace.spatial_indices, ispace.coordinate_members)
]
body.statements = mappings + body.statements
if not self._omit_range_check:
stop = PsExpression.make(ispace.index_list.shape[0])
condition = PsLt(sparse_ctr, stop)
ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
else:
body.statements = [sparse_idx_decl] + body.statements
ast = body
return ast, self.threads_from_ispace(ispace)
def _item_type(self, rank: int):
if not self._automatic_block_size:
return PsCustomType(f"sycl::nd_item< {rank} >", const=True)
else:
return PsCustomType(f"sycl::item< {rank} >", const=True)
def _id_type(self, rank: int):
return PsCustomType(f"sycl::id< {rank} >", const=True)
def _id_declaration(self, rank: int, id: PsSymbolExpr) -> PsDeclaration:
item_type = self._item_type(rank)
item = PsExpression.make(self._ctx.get_symbol("sycl_item", item_type))
if not self._automatic_block_size:
rhs = CppMethodCall(item, "get_global_id", self._id_type(rank))
else:
rhs = CppMethodCall(item, "get_id", self._id_type(rank))
return PsDeclaration(id, rhs)
from __future__ import annotations
from typing import Sequence
from enum import Enum
from functools import cache
from ..ast.expressions import (
PsExpression,
PsAddressOf,
PsMemAcc,
PsUnOp,
PsBinOp,
PsAdd,
PsSub,
PsMul,
PsDiv,
PsConstantExpr,
PsCast,
PsCall,
)
from ..ast.vector import PsVecMemAcc, PsVecBroadcast
from ...types import PsCustomType, PsVectorType, PsPointerType
from ..constants import PsConstant
from ..exceptions import MaterializationError
from .generic_cpu import GenericVectorCpu
from ..kernelcreation import KernelCreationContext
from ...types.quick import Fp, UInt, SInt
from ..functions import CFunction, PsMathFunction, MathFunctions
class X86VectorArch(Enum):
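    """The x86 vector instruction set architectures supported by pystencils.

    Intrinsic names are assembled as ``{prefix}_{op}_{suffix}``,
    e.g. ``_mm256_add_pd`` for a packed double-precision addition on AVX.
    """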
SSE = 128
AVX = 256
AVX512 = 512
AVX512_FP16 = AVX512 + 1 # TODO improve modelling?
def __ge__(self, other: X86VectorArch) -> bool:
return self.value >= other.value
def __gt__(self, other: X86VectorArch) -> bool:
return self.value > other.value
def __str__(self) -> str:
return self.name
@property
def max_vector_width(self) -> int:
return self.value
def intrin_prefix(self, vtype: PsVectorType) -> str:
match vtype.width:
case 128 if self >= X86VectorArch.SSE:
prefix = "_mm"
case 256 if self >= X86VectorArch.AVX:
prefix = "_mm256"
case 512 if self >= X86VectorArch.AVX512:
prefix = "_mm512"
case other:
raise MaterializationError(
f"x86/{self} does not support vector width {other}"
)
return prefix
def intrin_suffix(self, vtype: PsVectorType) -> str:
scalar_type = vtype.scalar_type
match scalar_type:
case Fp(16) if self >= X86VectorArch.AVX512_FP16:
suffix = "ph"
case Fp(32):
suffix = "ps"
case Fp(64):
suffix = "pd"
case SInt(width):
suffix = f"epi{width}"
case _:
raise MaterializationError(
f"x86/{self} does not support scalar type {scalar_type}"
)
return suffix
def intrin_type(self, vtype: PsVectorType):
scalar_type = vtype.scalar_type
match scalar_type:
case Fp(16) if self >= X86VectorArch.AVX512:
suffix = "h"
case Fp(32):
suffix = ""
case Fp(64):
suffix = "d"
case SInt(_):
suffix = "i"
case _:
raise MaterializationError(
f"x86/{self} does not support scalar type {scalar_type}"
)
if vtype.width > self.max_vector_width:
raise MaterializationError(f"x86/{self} does not support {vtype}")
return PsCustomType(f"__m{vtype.width}{suffix}")
class X86VectorCpu(GenericVectorCpu):
"""Platform modelling the X86 SSE/AVX/AVX512 vector architectures.
All intrinsics information is extracted from
https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html.
"""
def __init__(self, ctx: KernelCreationContext, vector_arch: X86VectorArch):
super().__init__(ctx)
self._vector_arch = vector_arch
@property
def vector_arch(self) -> X86VectorArch:
return self._vector_arch
@property
def required_headers(self) -> set[str]:
if self._vector_arch == X86VectorArch.SSE:
headers = {
"<immintrin.h>",
"<xmmintrin.h>",
"<emmintrin.h>",
"<pmmintrin.h>",
"<tmmintrin.h>",
"<smmintrin.h>",
"<nmmintrin.h>",
}
else:
headers = {"<immintrin.h>"}
return super().required_headers | headers
def type_intrinsic(self, vector_type: PsVectorType) -> PsCustomType:
return self._vector_arch.intrin_type(vector_type)
def constant_intrinsic(self, c: PsConstant) -> PsExpression:
vtype = c.dtype
assert isinstance(vtype, PsVectorType)
stype = vtype.scalar_type
prefix = self._vector_arch.intrin_prefix(vtype)
suffix = self._vector_arch.intrin_suffix(vtype)
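        # The `set` intrinsics taking 64-bit integer arguments carry an extra `x` suffix
        # (e.g. `_mm_set_epi64x`, `_mm256_set_epi64x`)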
if stype == SInt(64) and vtype.vector_entries <= 4:
suffix += "x"
set_func = CFunction(
f"{prefix}_set_{suffix}", (stype,) * vtype.vector_entries, vtype
)
values = [PsConstantExpr(PsConstant(v, stype)) for v in c.value]
return set_func(*values)
def op_intrinsic(
self, expr: PsExpression, operands: Sequence[PsExpression]
) -> PsExpression:
match expr:
case PsUnOp() | PsBinOp():
func = _x86_op_intrin(self._vector_arch, expr, expr.get_dtype())
intrinsic = func(*operands)
intrinsic.dtype = func.return_type
return intrinsic
case _:
raise MaterializationError(f"Cannot map {type(expr)} to x86 intrinsic")
def math_func_intrinsic(
self, expr: PsCall, operands: Sequence[PsExpression]
) -> PsExpression:
assert isinstance(expr.function, PsMathFunction)
vtype = expr.get_dtype()
assert isinstance(vtype, PsVectorType)
prefix = self._vector_arch.intrin_prefix(vtype)
suffix = self._vector_arch.intrin_suffix(vtype)
rtype = atype = self._vector_arch.intrin_type(vtype)
match expr.function.func:
case (
MathFunctions.Exp
| MathFunctions.Log
| MathFunctions.Sin
| MathFunctions.Cos
| MathFunctions.Tan
| MathFunctions.Sinh
| MathFunctions.Cosh
| MathFunctions.ASin
| MathFunctions.ACos
| MathFunctions.ATan
| MathFunctions.ATan2
| MathFunctions.Pow
):
raise MaterializationError(
"Trigonometry, exp, log, and pow require SVML."
)
case MathFunctions.Floor | MathFunctions.Ceil if vtype.is_float():
opstr = expr.function.func.function_name
if vtype.width > 256:
raise MaterializationError("512bit ceil/floor require SVML.")
case MathFunctions.Sqrt if vtype.is_float():
opstr = expr.function.name
case MathFunctions.Min | MathFunctions.Max:
opstr = expr.function.func.function_name
if (
vtype.is_int()
and vtype.scalar_type.width == 64
and self._vector_arch < X86VectorArch.AVX512
):
raise MaterializationError(
"64bit integer (signed and unsigned) min/max intrinsics require AVX512."
)
case MathFunctions.Abs:
assert len(operands) == 1, "abs takes exactly one argument."
op = operands[0]
match vtype.scalar_type:
case UInt():
return op
case SInt(width):
opstr = expr.function.func.function_name
if width == 64 and self._vector_arch < X86VectorArch.AVX512:
raise MaterializationError(
"64bit integer abs intrinsic requires AVX512."
)
case Fp():
neg_zero = self.constant_intrinsic(PsConstant(-0.0, vtype))
opstr = "andnot"
func = CFunction(
f"{prefix}_{opstr}_{suffix}", (atype,) * 2, rtype
)
return func(neg_zero, op)
case _:
raise MaterializationError(
f"x86/{self} does not support {expr.function.func.function_name} on type {vtype}."
)
if expr.function.func in [
MathFunctions.ATan2,
MathFunctions.Min,
MathFunctions.Max,
]:
num_args = 2
else:
num_args = 1
func = CFunction(f"{prefix}_{opstr}_{suffix}", (atype,) * num_args, rtype)
return func(*operands)
def vector_load(self, acc: PsVecMemAcc) -> PsExpression:
if acc.stride is None:
load_func, addr_type = _x86_packed_load(self._vector_arch, acc.dtype, False)
addr: PsExpression = PsAddressOf(PsMemAcc(acc.pointer, acc.offset))
if addr_type:
addr = PsCast(addr_type, addr)
intrinsic = load_func(addr)
intrinsic.dtype = load_func.return_type
return intrinsic
else:
raise NotImplementedError("Gather loads not implemented yet.")
def vector_store(self, acc: PsVecMemAcc, arg: PsExpression) -> PsExpression:
if acc.stride is None:
store_func, addr_type = _x86_packed_store(
self._vector_arch, acc.dtype, False
)
addr: PsExpression = PsAddressOf(PsMemAcc(acc.pointer, acc.offset))
if addr_type:
addr = PsCast(addr_type, addr)
intrinsic = store_func(addr, arg)
intrinsic.dtype = store_func.return_type
return intrinsic
else:
raise NotImplementedError("Scatter stores not implemented yet.")
@cache
def _x86_packed_load(
varch: X86VectorArch, vtype: PsVectorType, aligned: bool
) -> tuple[CFunction, PsPointerType | None]:
prefix = varch.intrin_prefix(vtype)
ptr_type = PsPointerType(vtype.scalar_type, const=True)
if isinstance(vtype.scalar_type, SInt):
suffix = f"si{vtype.width}"
addr_type = PsPointerType(varch.intrin_type(vtype))
else:
suffix = varch.intrin_suffix(vtype)
addr_type = None
return (
CFunction(
f"{prefix}_load{'' if aligned else 'u'}_{suffix}", (ptr_type,), vtype
),
addr_type,
)
@cache
def _x86_packed_store(
varch: X86VectorArch, vtype: PsVectorType, aligned: bool
) -> tuple[CFunction, PsPointerType | None]:
prefix = varch.intrin_prefix(vtype)
ptr_type = PsPointerType(vtype.scalar_type, const=True)
if isinstance(vtype.scalar_type, SInt):
suffix = f"si{vtype.width}"
addr_type = PsPointerType(varch.intrin_type(vtype))
else:
suffix = varch.intrin_suffix(vtype)
addr_type = None
return (
CFunction(
f"{prefix}_store{'' if aligned else 'u'}_{suffix}",
(ptr_type, vtype),
PsCustomType("void"),
),
addr_type,
)
@cache
def _x86_op_intrin(
varch: X86VectorArch, op: PsUnOp | PsBinOp, vtype: PsVectorType
) -> CFunction:
prefix = varch.intrin_prefix(vtype)
suffix = varch.intrin_suffix(vtype)
rtype = atype = varch.intrin_type(vtype)
match op:
case PsVecBroadcast():
opstr = "set1"
if vtype.scalar_type == SInt(64) and vtype.vector_entries <= 4:
suffix += "x"
atype = vtype.scalar_type
case PsAdd():
opstr = "add"
case PsSub():
opstr = "sub"
case PsMul() if vtype.is_int():
raise MaterializationError(
f"Unable to select intrinsic for integer multiplication: "
f"{varch.name} does not support packed integer multiplication.\n"
f" at: {op}"
)
case PsMul():
opstr = "mul"
case PsDiv():
opstr = "div"
case PsCast(target_type, arg):
atype = arg.dtype
widest_type = max(vtype, atype, key=lambda t: t.width)
assert target_type == vtype, "type mismatch"
assert isinstance(atype, PsVectorType)
def panic(detail: str = ""):
raise MaterializationError(
f"Unable to select intrinsic for type conversion: "
f"{varch.name} does not support packed conversion from {atype} to {target_type}. {detail}\n"
f" at: {op}"
)
if atype == vtype:
panic("Use `EliminateConstants` to eliminate trivial casts.")
match (atype.scalar_type, vtype.scalar_type):
# Not supported: cvtepi8_pX, cvtpX_epi8, cvtepi16_p[sd], cvtp[sd]_epi16
case (
(SInt(8), Fp())
| (Fp(), SInt(8))
| (SInt(16), Fp(32))
| (SInt(16), Fp(64))
| (Fp(32), SInt(16))
| (Fp(64), SInt(16))
):
panic()
# AVX512 only: cvtepi64_pX, cvtpX_epi64
case (SInt(64), Fp()) | (
Fp(),
SInt(64),
) if varch < X86VectorArch.AVX512:
panic()
# AVX512 only: cvtepiA_epiT if A > T
case (SInt(a), SInt(t)) if a > t and varch < X86VectorArch.AVX512:
panic()
case _:
prefix = varch.intrin_prefix(widest_type)
opstr = f"cvt{varch.intrin_suffix(atype)}"
case _:
raise MaterializationError(
f"Unable to select operation intrinsic for {type(op)}"
)
num_args = 1 if isinstance(op, PsUnOp) else 2
return CFunction(f"{prefix}_{opstr}_{suffix}", (atype,) * num_args, rtype)
"""
This module contains various transformation and optimization passes that can be
executed on the backend AST.
Canonical Form
==============
Many transformations in this module require that their input AST is in *canonical form*.
This means that:
- Each symbol, constant, and expression node is annotated with a data type;
- Each symbol has at most one declaration;
- Each symbol that is never written to apart from its declaration has a ``const`` type; and
- Each symbol whose type is *not* ``const`` has at least one non-declaring assignment.
The first requirement can be ensured by running the `Typifier` on each newly constructed subtree.
The other three requirements are ensured by the `CanonicalizeSymbols` pass,
which should be run first before applying any optimizing transformations.
All transformations in this module retain canonicality of the AST.
Canonicality allows transformations to forego various checks that would otherwise be necessary
to prove their legality.
Certain transformations, like the `LoopVectorizer`, state additional requirements, e.g.
the absence of loop-carried dependencies.
Transformations
===============
Canonicalization
----------------
.. autoclass:: CanonicalizeSymbols
:members: __call__
AST Cloning
-----------
.. autoclass:: CanonicalClone
:members: __call__
Simplifying Transformations
---------------------------
.. autoclass:: EliminateConstants
:members: __call__
.. autoclass:: EliminateBranches
:members: __call__
Code Rewriting
--------------
.. autofunction:: substitute_symbols
Code Motion
-----------
.. autoclass:: HoistLoopInvariantDeclarations
:members: __call__
Loop Reshaping Transformations
------------------------------
.. autoclass:: ReshapeLoops
:members:
.. autoclass:: InsertPragmasAtLoops
:members:
.. autoclass:: AddOpenMP
:members:
Vectorization
-------------
.. autoclass:: VectorizationAxis
:members:
.. autoclass:: VectorizationContext
:members:
.. autoclass:: AstVectorizer
:members:
.. autoclass:: LoopVectorizer
:members:
Code Lowering and Materialization
---------------------------------
.. autoclass:: LowerToC
:members: __call__
.. autoclass:: SelectFunctions
:members: __call__
.. autoclass:: SelectIntrinsics
:members:
"""
from .canonicalize_symbols import CanonicalizeSymbols
from .canonical_clone import CanonicalClone
from .rewrite import substitute_symbols
from .eliminate_constants import EliminateConstants
from .eliminate_branches import EliminateBranches
from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations
from .reshape_loops import ReshapeLoops
from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP
from .ast_vectorizer import VectorizationAxis, VectorizationContext, AstVectorizer
from .loop_vectorizer import LoopVectorizer
from .lower_to_c import LowerToC
from .select_functions import SelectFunctions
from .select_intrinsics import SelectIntrinsics
__all__ = [
"CanonicalizeSymbols",
"CanonicalClone",
"substitute_symbols",
"EliminateConstants",
"EliminateBranches",
"HoistLoopInvariantDeclarations",
"ReshapeLoops",
"InsertPragmasAtLoops",
"LoopPragma",
"AddOpenMP",
"VectorizationAxis",
"VectorizationContext",
"AstVectorizer",
"LoopVectorizer",
"LowerToC",
"SelectFunctions",
"SelectIntrinsics",
]
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
from collections import defaultdict
from ..kernelcreation import KernelCreationContext
from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsLoop, PsPragma
from ..ast.expressions import PsExpression
__all__ = ["InsertPragmasAtLoops", "LoopPragma", "AddOpenMP"]
@dataclass
class LoopPragma:
"""A pragma that should be prepended to loops at a certain nesting depth."""
text: str
"""The pragma text, without the ``#pragma ``"""
loop_nesting_depth: int
"""Nesting depth of the loops the pragma should be added to. ``-1`` indicates the innermost loops."""
def __post_init__(self):
if self.loop_nesting_depth < -1:
raise ValueError("Loop nesting depth must be nonnegative or -1.")
@dataclass
class Nesting:
depth: int
has_inner_loops: bool = False
class InsertPragmasAtLoops:
"""Insert pragmas before loops in a loop nest.
This transformation augments the AST with pragma directives which are prepended to loops.
The directives are annotated with the nesting depth of the loops they should be added to,
where ``-1`` indicates the innermost loop.
The relative order of pragmas with the (exact) same nesting depth is preserved;
however, no guarantees are given about the relative order of pragmas inserted at ``-1``
and at the actual depth of the innermost loop.
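
    **Example**

    A sketch that prepends an (illustrative) unrolling hint to every innermost loop::

        from pystencils.backend.kernelcreation import KernelCreationContext
        from pystencils.backend.transformations import InsertPragmasAtLoops, LoopPragma

        ctx = KernelCreationContext()
        unroll_inner = InsertPragmasAtLoops(ctx, [LoopPragma("unroll 4", -1)])
        #   for a given loop nest `ast`:
        #   ast = unroll_inner(ast)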
"""
def __init__(
self, ctx: KernelCreationContext, insertions: Sequence[LoopPragma]
) -> None:
self._ctx = ctx
self._insertions: dict[int, list[LoopPragma]] = defaultdict(list)
for ins in insertions:
self._insertions[ins.loop_nesting_depth].append(ins)
def __call__(self, node: PsAstNode) -> PsAstNode:
is_loop = isinstance(node, PsLoop)
if is_loop:
node = PsBlock([node])
self.visit(node, Nesting(0))
if is_loop and len(node.children) == 1:
node = node.children[0]
return node
def visit(self, node: PsAstNode, nest: Nesting) -> None:
match node:
case PsExpression():
return
case PsBlock(children):
new_children: list[PsAstNode] = []
for c in children:
if isinstance(c, PsLoop):
nest.has_inner_loops = True
inner_nest = Nesting(nest.depth + 1)
self.visit(c.body, inner_nest)
if not inner_nest.has_inner_loops:
# c is the innermost loop
for pragma in self._insertions[-1]:
new_children.append(PsPragma(pragma.text))
for pragma in self._insertions[nest.depth]:
new_children.append(PsPragma(pragma.text))
new_children.append(c)
node.children = new_children
case other:
for c in other.children:
self.visit(c, nest)
class AddOpenMP:
"""Apply OpenMP directives to loop nests.
This transformation augments the AST with OpenMP pragmas according to the given configuration.
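
    **Example**

    A sketch that statically parallelizes the outermost loop level over eight threads::

        from pystencils.backend.kernelcreation import KernelCreationContext
        from pystencils.backend.transformations import AddOpenMP

        ctx = KernelCreationContext()
        add_omp = AddOpenMP(ctx, nesting_depth=0, num_threads=8, schedule="static")
        #   for a given loop nest `ast`:
        #   ast = add_omp(ast)
        #   prepends `#pragma omp parallel for schedule(static) num_threads(8)`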
"""
def __init__(
self,
ctx: KernelCreationContext,
nesting_depth: int = 0,
num_threads: int | None = None,
schedule: str | None = None,
collapse: int | None = None,
omit_parallel: bool = False,
) -> None:
pragma_text = "omp"
if not omit_parallel:
pragma_text += " parallel"
pragma_text += " for"
if schedule is not None:
pragma_text += f" schedule({schedule})"
if num_threads is not None:
pragma_text += f" num_threads({str(num_threads)})"
if collapse is not None:
if collapse <= 0:
raise ValueError(
f"Invalid value for OpenMP `collapse` clause: {collapse}"
)
pragma_text += f" collapse({str(collapse)})"
self._insert_pragmas = InsertPragmasAtLoops(
ctx, [LoopPragma(pragma_text, nesting_depth)]
)
def __call__(self, node: PsAstNode) -> PsAstNode:
return self._insert_pragmas(node)
from __future__ import annotations
from textwrap import indent
from typing import cast, overload
from dataclasses import dataclass
from ...types import PsType, PsVectorType, PsBoolType, PsScalarType
from ..kernelcreation import KernelCreationContext, AstFactory
from ..memory import PsSymbol
from ..constants import PsConstant
from ..functions import PsMathFunction
from ..ast import PsAstNode
from ..ast.structural import (
PsBlock,
PsDeclaration,
PsAssignment,
PsLoop,
PsEmptyLeafMixIn,
)
from ..ast.expressions import (
PsExpression,
PsAddressOf,
PsCast,
PsUnOp,
PsBinOp,
PsSymbolExpr,
PsConstantExpr,
PsLiteral,
PsCall,
PsMemAcc,
PsBufferAcc,
PsSubscript,
PsAdd,
PsMul,
PsSub,
PsNeg,
PsDiv,
)
from ..ast.vector import PsVectorOp, PsVecBroadcast, PsVecMemAcc
from ..ast.analysis import UndefinedSymbolsCollector
from ..exceptions import PsInternalCompilerError, VectorizationError
@dataclass(frozen=True)
class VectorizationAxis:
"""Information about the iteration axis along which a subtree is being vectorized."""
counter: PsSymbol
"""Scalar iteration counter of this axis"""
vectorized_counter: PsSymbol | None = None
"""Vectorized iteration counter of this axis"""
step: PsExpression = PsExpression.make(PsConstant(1))
"""Step size of the scalar iteration"""
def get_vectorized_counter(self) -> PsSymbol:
if self.vectorized_counter is None:
raise PsInternalCompilerError(
"No vectorized counter defined on this vectorization axis"
)
return self.vectorized_counter
class VectorizationContext:
"""Context information for AST vectorization.
Args:
lanes: Number of vector lanes
axis: Iteration axis along which code is being vectorized
"""
def __init__(
self,
ctx: KernelCreationContext,
lanes: int,
axis: VectorizationAxis,
vectorized_symbols: dict[PsSymbol, PsSymbol] | None = None,
) -> None:
self._ctx = ctx
self._lanes = lanes
self._axis: VectorizationAxis = axis
self._vectorized_symbols: dict[PsSymbol, PsSymbol] = (
{**vectorized_symbols} if vectorized_symbols is not None else dict()
)
self._lane_mask: PsSymbol | None = None
if axis.vectorized_counter is not None:
self._vectorized_symbols[axis.counter] = axis.vectorized_counter
@property
def lanes(self) -> int:
"""Number of vector lanes"""
return self._lanes
@property
def axis(self) -> VectorizationAxis:
"""Iteration axis along which to vectorize"""
return self._axis
@property
def vectorized_symbols(self) -> dict[PsSymbol, PsSymbol]:
"""Dictionary mapping scalar symbols that are being vectorized to their vectorized copies"""
return self._vectorized_symbols
@property
def lane_mask(self) -> PsSymbol | None:
"""Symbol representing the current lane execution mask, or ``None`` if all lanes are active."""
return self._lane_mask
@lane_mask.setter
def lane_mask(self, mask: PsSymbol | None):
self._lane_mask = mask
def get_lane_mask_expr(self) -> PsExpression:
"""Retrieve an expression representing the current lane execution mask."""
if self._lane_mask is not None:
return PsExpression.make(self._lane_mask)
else:
return PsExpression.make(
PsConstant(True, PsVectorType(PsBoolType(), self._lanes))
)
def vectorize_symbol(self, symb: PsSymbol) -> PsSymbol:
"""Vectorize the given symbol of scalar type.
Creates a duplicate of the given symbol with vectorized data type,
adds it to the ``vectorized_symbols`` dict,
and returns the duplicate.
Raises:
VectorizationError: If the symbol's data type was not a `PsScalarType`,
or if the symbol was already vectorized
"""
if symb in self._vectorized_symbols:
raise VectorizationError(f"Symbol {symb} was already vectorized.")
vec_type = self.vector_type(symb.get_dtype())
vec_symb = self._ctx.duplicate_symbol(symb, vec_type)
self._vectorized_symbols[symb] = vec_symb
return vec_symb
def vector_type(self, scalar_type: PsType) -> PsVectorType:
"""Vectorize the given scalar data type.
Raises:
VectorizationError: If the given data type was not a `PsScalarType`.
"""
if not isinstance(scalar_type, PsScalarType):
raise VectorizationError(
f"Unable to vectorize type {scalar_type}: was not a scalar numeric type"
)
return PsVectorType(scalar_type, self._lanes)
def axis_ctr_dependees(self, symbols: set[PsSymbol]) -> set[PsSymbol]:
"""Returns all symbols in `symbols` that depend on the axis counter."""
return symbols & (self.vectorized_symbols.keys() | {self.axis.counter})
@dataclass
class Affine:
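    """Helper representing an expression that is affine in the axis counter, i.e. ``coeff * ctr + offset``."""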
coeff: PsExpression
offset: PsExpression
def __neg__(self):
return Affine(-self.coeff, -self.offset)
def __add__(self, other: Affine):
return Affine(self.coeff + other.coeff, self.offset + other.offset)
def __sub__(self, other: Affine):
return Affine(self.coeff - other.coeff, self.offset - other.offset)
def __mul__(self, factor: PsExpression):
if not isinstance(factor, PsExpression):
return NotImplemented
return Affine(self.coeff * factor, self.offset * factor)
def __rmul__(self, factor: PsExpression):
if not isinstance(factor, PsExpression):
return NotImplemented
return Affine(self.coeff * factor, self.offset * factor)
def __truediv__(self, divisor: PsExpression):
if not isinstance(divisor, PsExpression):
return NotImplemented
return Affine(self.coeff / divisor, self.offset / divisor)
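# Illustrative note: `AstVectorizer._index_as_affine` (below) decomposes an index
# expression with respect to the axis counter. E.g., with a placeholder counter
# ``ctr`` and a lane-invariant expression ``x``, the index ``3 * ctr + x`` is
# rewritten as ``Affine(coeff=3, offset=x)``; the access stride then follows as
# ``coeff * axis.step`` (times the buffer stride in the case of `PsBufferAcc`).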
class AstVectorizer:
"""Transform a scalar subtree into a SIMD-parallel version of itself.
The `AstVectorizer` constructs a vectorized copy of a subtree by creating a SIMD-parallel
version of each of its nodes, one at a time.
It relies on information given in a `VectorizationContext` that defines the current environment,
including the vectorization axis, the number of vector lanes, and an execution mask determining
which vector lanes are active.
**Memory Accesses:**
The AST vectorizer is capable of vectorizing `PsMemAcc` and `PsBufferAcc` only under certain circumstances:
- If all indices are independent of both the vectorization axis' counter and any vectorized symbols,
the memory access is *lane-invariant*, and its result will be broadcast to all vector lanes.
- If at most one index depends on the axis counter via an affine expression, and does not depend on any
vectorized symbols, the memory access can be performed in parallel, either contiguously or strided,
and is replaced by a `PsVecMemAcc`.
- All other cases cause vectorization to fail.
**Legality:**
The AST vectorizer performs no legality checks and in particular assumes the absence of loop-carried
dependencies; i.e. all iterations of the vectorized subtree must already be independent of each
other, and insensitive to execution order.
**Result and Failures:**
The AST vectorizer does not alter the original subtree, but constructs and returns a copy of it.
Any symbols declared within the subtree are therein replaced by canonically renamed,
vectorized copies of themselves.
If the AST vectorizer is unable to transform a subtree, it raises a `VectorizationError`.
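
    Example:
        A rough usage sketch; ``ctx``, the counter symbols ``ctr`` and
        ``vec_ctr``, and the subtree ``loop_body`` (a `PsBlock`) are assumed
        to be set up elsewhere::

            vectorizer = AstVectorizer(ctx)
            axis = VectorizationAxis(counter=ctr, vectorized_counter=vec_ctr)
            vc = VectorizationContext(ctx, lanes=8, axis=axis)
            vectorized_body = vectorizer(loop_body, vc)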
"""
def __init__(self, ctx: KernelCreationContext):
self._ctx = ctx
self._factory = AstFactory(ctx)
self._collect_symbols = UndefinedSymbolsCollector()
from ..kernelcreation import Typifier
from .eliminate_constants import EliminateConstants
from .lower_to_c import LowerToC
        self._typify = Typifier(ctx)
self._fold_constants = EliminateConstants(ctx)
self._lower_to_c = LowerToC(ctx)
@overload
def __call__(self, node: PsBlock, vc: VectorizationContext) -> PsBlock:
pass
@overload
def __call__(self, node: PsDeclaration, vc: VectorizationContext) -> PsDeclaration:
pass
@overload
def __call__(self, node: PsAssignment, vc: VectorizationContext) -> PsAssignment:
pass
@overload
def __call__(self, node: PsExpression, vc: VectorizationContext) -> PsExpression:
pass
@overload
def __call__(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
pass
def __call__(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
"""Perform subtree vectorization.
Args:
node: Root of the subtree that should be vectorized
vc: Object describing the current vectorization context
Raises:
VectorizationError: If a node cannot be vectorized
"""
return self.visit(node, vc)
def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
"""Vectorize a subtree."""
match node:
case PsBlock(stmts):
return PsBlock([self.visit(n, vc) for n in stmts])
case PsExpression():
return self.visit_expr(node, vc)
case PsDeclaration(_, rhs):
vec_symb = vc.vectorize_symbol(node.declared_symbol)
vec_lhs = PsExpression.make(vec_symb)
vec_rhs = self.visit_expr(rhs, vc)
return PsDeclaration(vec_lhs, vec_rhs)
case PsAssignment(lhs, rhs):
if (
isinstance(lhs, PsSymbolExpr)
and lhs.symbol in vc.vectorized_symbols
):
return PsAssignment(
self.visit_expr(lhs, vc), self.visit_expr(rhs, vc)
)
if not isinstance(lhs, (PsMemAcc, PsBufferAcc)):
raise VectorizationError(f"Unable to vectorize assignment to {lhs}")
lhs_vec = self.visit_expr(lhs, vc)
if not isinstance(lhs_vec, PsVecMemAcc):
raise VectorizationError(
f"Unable to vectorize memory write {node}:\n"
f"Index did not depend on axis counter."
)
rhs_vec = self.visit_expr(rhs, vc)
return PsAssignment(lhs_vec, rhs_vec)
case PsLoop(counter, start, stop, step, body):
# Check that loop bounds are lane-invariant
free_symbols = (
self._collect_symbols(start)
| self._collect_symbols(stop)
| self._collect_symbols(step)
)
vec_dependencies = vc.axis_ctr_dependees(free_symbols)
if vec_dependencies:
raise VectorizationError(
"Unable to vectorize loop depending on vectorized symbols:\n"
f" Offending dependencies:\n"
f" {vec_dependencies}\n"
f" Found in loop:\n"
f"{indent(str(node), ' ')}"
)
vectorized_body = cast(PsBlock, self.visit(body, vc))
return PsLoop(counter, start, stop, step, vectorized_body)
case PsEmptyLeafMixIn():
return node
case _:
raise NotImplementedError(f"Vectorization of {node} is not implemented")
def visit_expr(self, expr: PsExpression, vc: VectorizationContext) -> PsExpression:
"""Vectorize an expression."""
vec_expr: PsExpression
scalar_type = expr.get_dtype()
match expr:
# Invalids
case PsVectorOp() | PsAddressOf():
raise VectorizationError(f"Unable to vectorize {type(expr)}: {expr}")
# Symbols
case PsSymbolExpr(symb) if symb in vc.vectorized_symbols:
# Vectorize symbol
vector_symb = vc.vectorized_symbols[symb]
vec_expr = PsSymbolExpr(vector_symb)
case PsSymbolExpr(symb) if symb == vc.axis.counter:
raise VectorizationError(
f"Unable to vectorize occurence of axis counter {symb} "
"since no vectorized version of the counter was present in the context."
)
# Symbols, constants, and literals that can be broadcast
            case PsSymbolExpr() | PsConstantExpr() | PsLiteralExpr():
if isinstance(expr.dtype, PsScalarType):
# Broadcast constant or non-vectorized scalar symbol
vec_expr = PsVecBroadcast(vc.lanes, expr.clone())
else:
# Cannot vectorize non-scalar constants or symbols
raise VectorizationError(
f"Unable to vectorize expression {expr} of non-scalar data type {expr.dtype}"
)
# Unary Ops
case PsCast(target_type, operand):
vec_expr = PsCast(
vc.vector_type(target_type), self.visit_expr(operand, vc)
)
case PsUnOp(operand):
vec_expr = type(expr)(self.visit_expr(operand, vc))
# Binary Ops
case PsBinOp(op1, op2):
vec_expr = type(expr)(
self.visit_expr(op1, vc), self.visit_expr(op2, vc)
)
# Math Functions
case PsCall(PsMathFunction(func), func_args):
vec_expr = PsCall(
PsMathFunction(func),
[self.visit_expr(arg, vc) for arg in func_args],
)
# Other Functions
case PsCall(func, _):
raise VectorizationError(
f"Unable to vectorize function call to {func}."
)
# Memory Accesses
case PsMemAcc(ptr, offset):
if not isinstance(ptr, PsSymbolExpr):
raise VectorizationError(
f"Unable to vectorize memory access by non-symbol pointer {ptr}"
)
idx_affine = self._index_as_affine(offset, vc)
if idx_affine is None:
vec_expr = PsVecBroadcast(vc.lanes, expr.clone())
else:
stride: PsExpression | None = self._fold_constants(
                        self._typify(idx_affine.coeff * vc.axis.step)
)
if (
isinstance(stride, PsConstantExpr)
and stride.constant.value == 1
):
# Contiguous access
stride = None
vec_expr = PsVecMemAcc(
ptr.clone(), offset.clone(), vc.lanes, stride
)
case PsBufferAcc(ptr, indices):
buf = expr.buffer
ctr_found = False
access_stride: PsExpression | None = None
for i, idx in enumerate(indices):
idx_affine = self._index_as_affine(idx, vc)
if idx_affine is not None:
if ctr_found:
raise VectorizationError(
f"Unable to vectorize buffer access {expr}: "
f"Found multiple indices that depend on iteration counter {vc.axis.counter}."
)
ctr_found = True
                        access_stride = self._fold_constants(
                            self._typify(
                                idx_affine.coeff
                                * vc.axis.step
                                * PsExpression.make(buf.strides[i])
                            )
                        )
if ctr_found:
# Buffer access must be vectorized
assert access_stride is not None
if (
isinstance(access_stride, PsConstantExpr)
and access_stride.constant.value == 1
):
# Contiguous access
access_stride = None
linearized_acc = self._lower_to_c(expr)
assert isinstance(linearized_acc, PsMemAcc)
vec_expr = PsVecMemAcc(
ptr.clone(),
linearized_acc.offset.clone(),
vc.lanes,
access_stride,
)
else:
# Buffer access is lane-invariant
vec_expr = PsVecBroadcast(vc.lanes, expr.clone())
case PsSubscript(array, index):
# Check that array expression and indices are lane-invariant
free_symbols = self._collect_symbols(array).union(
*[self._collect_symbols(i) for i in index]
)
vec_dependencies = vc.axis_ctr_dependees(free_symbols)
if vec_dependencies:
raise VectorizationError(
"Unable to vectorize array subscript depending on vectorized symbols:\n"
f" Offending dependencies:\n"
f" {vec_dependencies}\n"
f" Found in expression:\n"
f"{indent(str(expr), ' ')}"
)
vec_expr = PsVecBroadcast(vc.lanes, expr.clone())
case _:
raise NotImplementedError(
f"Vectorization of {type(expr)} is not implemented"
)
vec_expr.dtype = vc.vector_type(scalar_type)
return vec_expr
def _index_as_affine(
self, idx: PsExpression, vc: VectorizationContext
) -> Affine | None:
"""Attempt to analyze an index expression as an affine expression of the axis counter."""
free_symbols = self._collect_symbols(idx)
# Check if all symbols except for the axis counter are lane-invariant
for symb in free_symbols:
if symb != vc.axis.counter and symb in vc.vectorized_symbols:
raise VectorizationError(
"Unable to rewrite index as affine expression of axis counter: \n"
f" {idx}\n"
f"Expression depends on non-lane-invariant symbol {symb}"
)
if vc.axis.counter not in free_symbols:
# Index is lane-invariant
return None
zero = self._factory.parse_index(0)
one = self._factory.parse_index(1)
def lane_invariant(expr) -> bool:
return vc.axis.counter not in self._collect_symbols(expr)
def collect(subexpr) -> Affine:
match subexpr:
case PsSymbolExpr(symb) if symb == vc.axis.counter:
return Affine(one, zero)
case _ if lane_invariant(subexpr):
return Affine(zero, subexpr)
case PsNeg(op):
return -collect(op)
case PsAdd(op1, op2):
return collect(op1) + collect(op2)
case PsSub(op1, op2):
return collect(op1) - collect(op2)
case PsMul(op1, op2) if lane_invariant(op1):
return op1 * collect(op2)
case PsMul(op1, op2) if lane_invariant(op2):
return collect(op1) * op2
case PsDiv(op1, op2) if lane_invariant(op2):
return collect(op1) / op2
case _:
raise VectorizationError(
"Unable to rewrite index as affine expression of axis counter: \n"
f" {idx}\n"
f"Encountered invalid subexpression {subexpr}"
)
return collect(idx)
from typing import TypeVar, cast
from ..kernelcreation import KernelCreationContext
from ..memory import PsSymbol
from ..exceptions import PsInternalCompilerError
from ..ast import PsAstNode
from ..ast.structural import (
PsBlock,
PsConditional,
PsLoop,
PsDeclaration,
PsAssignment,
PsComment,
PsPragma,
PsStatement,
)
from ..ast.expressions import PsExpression, PsSymbolExpr
__all__ = ["CanonicalClone"]
class CloneContext:
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
self._dup_table: dict[PsSymbol, PsSymbol] = dict()
def symbol_decl(self, declared_symbol: PsSymbol):
self._dup_table[declared_symbol] = self._ctx.duplicate_symbol(declared_symbol)
def get_replacement(self, symb: PsSymbol):
return self._dup_table.get(symb, symb)
Node_T = TypeVar("Node_T", bound=PsAstNode)
class CanonicalClone:
"""Clone a subtree, and rename all symbols declared inside it to retain canonicality."""
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
def __call__(self, node: Node_T) -> Node_T:
return self.visit(node, CloneContext(self._ctx))
def visit(self, node: Node_T, cc: CloneContext) -> Node_T:
match node:
case PsBlock(statements):
return cast(
Node_T, PsBlock([self.visit(stmt, cc) for stmt in statements])
)
case PsLoop(ctr, start, stop, step, body):
cc.symbol_decl(ctr.symbol)
return cast(
Node_T,
PsLoop(
self.visit(ctr, cc),
self.visit(start, cc),
self.visit(stop, cc),
self.visit(step, cc),
self.visit(body, cc),
),
)
case PsConditional(cond, then, els):
return cast(
Node_T,
PsConditional(
self.visit(cond, cc),
self.visit(then, cc),
self.visit(els, cc) if els is not None else None,
),
)
case PsComment() | PsPragma():
return cast(Node_T, node.clone())
case PsDeclaration(lhs, rhs):
cc.symbol_decl(node.declared_symbol)
return cast(
Node_T,
PsDeclaration(
cast(PsSymbolExpr, self.visit(lhs, cc)),
self.visit(rhs, cc),
),
)
case PsAssignment(lhs, rhs):
return cast(
Node_T,
PsAssignment(
self.visit(lhs, cc),
self.visit(rhs, cc),
),
)
case PsExpression():
expr_clone = node.clone()
self._replace_symbols(expr_clone, cc)
return cast(Node_T, expr_clone)
case PsStatement(expr):
return cast(Node_T, PsStatement(self.visit(expr, cc)))
case _:
raise PsInternalCompilerError(
f"Don't know how to canonically clone {type(node)}"
)
def _replace_symbols(self, expr: PsExpression, cc: CloneContext):
if isinstance(expr, PsSymbolExpr):
expr.symbol = cc.get_replacement(expr.symbol)
else:
for c in expr.children:
self._replace_symbols(cast(PsExpression, c), cc)
from ..kernelcreation import KernelCreationContext
from ..memory import PsSymbol
from ..exceptions import PsInternalCompilerError
from ..ast import PsAstNode
from ..ast.structural import PsDeclaration, PsAssignment, PsLoop, PsConditional, PsBlock
from ..ast.expressions import PsSymbolExpr, PsExpression
from ...types import constify
__all__ = ["CanonicalizeSymbols"]
class CanonContext:
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
self.encountered_symbols: set[PsSymbol] = set()
self.live_symbols_map: dict[PsSymbol, PsSymbol] = dict()
self.updated_symbols: set[PsSymbol] = set()
def deduplicate(self, symb: PsSymbol) -> PsSymbol:
if symb in self.live_symbols_map:
return self.live_symbols_map[symb]
elif symb not in self.encountered_symbols:
self.encountered_symbols.add(symb)
self.live_symbols_map[symb] = symb
return symb
else:
replacement = self._ctx.duplicate_symbol(symb)
self.live_symbols_map[symb] = replacement
self.encountered_symbols.add(replacement)
return replacement
def mark_as_updated(self, symb: PsSymbol):
self.updated_symbols.add(symb)
def is_live(self, symb: PsSymbol) -> bool:
return symb in self.live_symbols_map
def end_lifespan(self, symb: PsSymbol):
if symb in self.live_symbols_map:
del self.live_symbols_map[symb]
class CanonicalizeSymbols:
"""Remove duplicate symbol declarations and declare all non-updated symbols ``const``.
The `CanonicalizeSymbols` pass will remove multiple declarations of the same symbol by
    renaming all but the last occurrence, and will optionally ``const``-qualify all symbols
encountered in the AST that are never updated.
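
    Example:
        Typical invocation (sketch; ``ctx`` and ``ast`` are assumed)::

            canonicalize = CanonicalizeSymbols(ctx, constify=True)
            ast = canonicalize(ast)
            live_symbols = canonicalize.get_last_live_symbols()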
"""
def __init__(self, ctx: KernelCreationContext, constify: bool = True) -> None:
self._ctx = ctx
self._constify = constify
self._last_result: CanonContext | None = None
def get_last_live_symbols(self) -> set[PsSymbol]:
if self._last_result is None:
raise PsInternalCompilerError("Pass was not executed yet")
return set(self._last_result.live_symbols_map.values())
def __call__(self, node: PsAstNode) -> PsAstNode:
cc = CanonContext(self._ctx)
self.visit(node, cc)
# Any symbol encountered but never updated can be marked const
if self._constify:
for symb in cc.encountered_symbols - cc.updated_symbols:
if symb.dtype is not None:
symb.dtype = constify(symb.dtype)
# Any symbols still alive now are function params or globals
self._last_result = cc
return node
def visit(self, node: PsAstNode, cc: CanonContext):
"""Traverse the AST in reverse pre-order to collect, deduplicate, and maybe constify all live symbols."""
match node:
case PsSymbolExpr(symb):
node.symbol = cc.deduplicate(symb)
return node
case PsExpression():
for c in node.children:
self.visit(c, cc)
case PsDeclaration(lhs, rhs):
decl_symb = node.declared_symbol
self.visit(lhs, cc)
self.visit(rhs, cc)
cc.end_lifespan(decl_symb)
case PsAssignment(lhs, rhs):
self.visit(lhs, cc)
self.visit(rhs, cc)
if isinstance(lhs, PsSymbolExpr):
cc.mark_as_updated(lhs.symbol)
case PsLoop(ctr, _, _, _, _):
decl_symb = ctr.symbol
for c in node.children[::-1]:
self.visit(c, cc)
cc.mark_as_updated(ctr.symbol)
cc.end_lifespan(decl_symb)
case PsConditional(cond, then, els):
if els is not None:
self.visit(els, cc)
self.visit(then, cc)
self.visit(cond, cc)
case PsBlock(statements):
for stmt in statements[::-1]:
self.visit(stmt, cc)
from ..kernelcreation import KernelCreationContext
from ..ast import PsAstNode
from ..ast.analysis import collect_undefined_symbols
from ..ast.structural import PsLoop, PsBlock, PsConditional
from ..ast.expressions import (
PsAnd,
PsCast,
PsConstant,
PsConstantExpr,
PsDiv,
PsEq,
PsExpression,
PsGe,
PsGt,
PsIntDiv,
PsLe,
PsLt,
PsMul,
PsNe,
PsNeg,
PsNot,
PsOr,
PsSub,
PsSymbolExpr,
PsAdd,
)
from .eliminate_constants import EliminateConstants
from ...types import PsBoolType, PsIntegerType
__all__ = ["EliminateBranches"]
class IslAnalysisError(Exception):
"""Indicates a fatal error during integer set analysis (based on islpy)"""
class BranchElimContext:
def __init__(self) -> None:
self.enclosing_loops: list[PsLoop] = []
self.enclosing_conditions: list[PsExpression] = []
class EliminateBranches:
"""Replace conditional branches by their then- or else-branch if their condition can be unequivocally
evaluated.
This pass will attempt to evaluate branch conditions within their context in the AST, and replace
conditionals by either their then- or their else-block if the branch is unequivocal.
If islpy is installed, this pass will incorporate information about the iteration regions
of enclosing loops and enclosing conditionals into its analysis.
Args:
use_isl (bool, optional): enable islpy based analysis (default: True)
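
    Example:
        A hedged invocation sketch (``ctx`` and ``ast`` assumed)::

            elim_branches = EliminateBranches(ctx)
            ast = elim_branches(ast)

        For instance, a conditional guarding ``ctr < 128`` inside a loop that
        iterates over ``0 <= ctr < 64`` is unequivocally true and is replaced
        by its then-block.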
"""
def __init__(self, ctx: KernelCreationContext, use_isl: bool = True) -> None:
self._ctx = ctx
self._use_isl = use_isl
self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False)
def __call__(self, node: PsAstNode) -> PsAstNode:
return self.visit(node, BranchElimContext())
def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode:
match node:
case PsLoop(_, _, _, _, body):
ec.enclosing_loops.append(node)
self.visit(body, ec)
ec.enclosing_loops.pop()
case PsBlock(statements):
statements_new: list[PsAstNode] = []
for stmt in statements:
statements_new.append(self.visit(stmt, ec))
node.statements = statements_new
case PsConditional():
result = self.handle_conditional(node, ec)
match result:
case PsConditional(_, branch_true, branch_false):
ec.enclosing_conditions.append(result.condition)
self.visit(branch_true, ec)
ec.enclosing_conditions.pop()
if branch_false is not None:
ec.enclosing_conditions.append(PsNot(result.condition))
self.visit(branch_false, ec)
ec.enclosing_conditions.pop()
case PsBlock():
self.visit(result, ec)
case None:
result = PsBlock([])
case _:
assert False, "unreachable code"
return result
return node
def handle_conditional(
self, conditional: PsConditional, ec: BranchElimContext
) -> PsConditional | PsBlock | None:
condition_simplified = self._elim_constants(conditional.condition)
if self._use_isl:
condition_simplified = self._isl_simplify_condition(
condition_simplified, ec
)
match condition_simplified:
case PsConstantExpr(c) if c.value:
return conditional.branch_true
case PsConstantExpr(c) if not c.value:
return conditional.branch_false
return conditional
def _isl_simplify_condition(
self, condition: PsExpression, ec: BranchElimContext
) -> PsExpression:
"""If installed, use ISL to simplify the passed condition to true or
false based on enclosing loops and conditionals. If no simplification
can be made or ISL is not installed, the original condition is returned.
"""
try:
import islpy as isl
except ImportError:
return condition
def printer(expr: PsExpression):
match expr:
case PsSymbolExpr(symbol):
return symbol.name
case PsConstantExpr(constant):
dtype = constant.get_dtype()
if not isinstance(dtype, (PsIntegerType, PsBoolType)):
raise IslAnalysisError(
"Only scalar integer and bool constant may appear in isl expressions."
)
return str(constant.value)
case PsAdd(op1, op2):
return f"({printer(op1)} + {printer(op2)})"
case PsSub(op1, op2):
return f"({printer(op1)} - {printer(op2)})"
case PsMul(op1, op2):
return f"({printer(op1)} * {printer(op2)})"
case PsDiv(op1, op2) | PsIntDiv(op1, op2):
return f"({printer(op1)} / {printer(op2)})"
case PsAnd(op1, op2):
return f"({printer(op1)} and {printer(op2)})"
case PsOr(op1, op2):
return f"({printer(op1)} or {printer(op2)})"
case PsEq(op1, op2):
return f"({printer(op1)} = {printer(op2)})"
case PsNe(op1, op2):
return f"({printer(op1)} != {printer(op2)})"
case PsGt(op1, op2):
return f"({printer(op1)} > {printer(op2)})"
case PsGe(op1, op2):
return f"({printer(op1)} >= {printer(op2)})"
case PsLt(op1, op2):
return f"({printer(op1)} < {printer(op2)})"
case PsLe(op1, op2):
return f"({printer(op1)} <= {printer(op2)})"
case PsNeg(operand):
return f"(-{printer(operand)})"
case PsNot(operand):
return f"(not {printer(operand)})"
case PsCast(_, operand):
return printer(operand)
case _:
raise IslAnalysisError(
f"Not supported by isl or don't know how to print {expr}"
)
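        # For example (with hypothetical symbol names), an enclosing loop over
        # ``0 <= ctr_0 < N`` and the condition ``ctr_0 + 1 < N`` produce the sets
        #     outer: { [ctr_0,N] : ctr_0 >= 0 and ctr_0 < N }
        #     inner: { [ctr_0,N] : ((ctr_0 + 1) < N) }
        # which are intersected below to decide whether the branch is unequivocal.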
dofs = collect_undefined_symbols(condition)
outer_conditions = []
for loop in ec.enclosing_loops:
if not (
isinstance(loop.step, PsConstantExpr)
and loop.step.constant.value == 1
):
raise IslAnalysisError(
"Loops with strides != 1 are not yet supported."
)
dofs.add(loop.counter.symbol)
dofs.update(collect_undefined_symbols(loop.start))
dofs.update(collect_undefined_symbols(loop.stop))
loop_start_str = printer(loop.start)
loop_stop_str = printer(loop.stop)
ctr_name = loop.counter.symbol.name
outer_conditions.append(
f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"
)
for cond in ec.enclosing_conditions:
dofs.update(collect_undefined_symbols(cond))
outer_conditions.append(printer(cond))
dofs_str = ",".join(dof.name for dof in dofs)
outer_conditions_str = " and ".join(outer_conditions)
condition_str = printer(condition)
outer_set = isl.BasicSet(f"{{ [{dofs_str}] : {outer_conditions_str} }}")
inner_set = isl.BasicSet(f"{{ [{dofs_str}] : {condition_str} }}")
if inner_set.is_empty():
return PsExpression.make(PsConstant(False))
intersection = outer_set.intersect(inner_set)
if intersection.is_empty():
return PsExpression.make(PsConstant(False))
elif intersection == outer_set:
return PsExpression.make(PsConstant(True))
else:
return condition
from typing import cast, Iterable, overload
from collections import defaultdict
import numpy as np
from ..kernelcreation import KernelCreationContext, Typifier
from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsDeclaration
from ..ast.expressions import (
PsExpression,
PsConstantExpr,
PsSymbolExpr,
PsLiteralExpr,
PsBinOp,
PsAdd,
PsSub,
PsMul,
PsDiv,
PsIntDiv,
PsRem,
PsAnd,
PsOr,
PsRel,
PsNeg,
PsNot,
PsCall,
PsEq,
PsGe,
PsLe,
PsLt,
PsGt,
PsNe,
PsTernary,
PsCast,
)
from ..ast.vector import PsVecBroadcast
from ..ast.util import AstEqWrapper
from ..constants import PsConstant
from ..memory import PsSymbol
from ..functions import PsMathFunction
from ...types import PsNumericType, PsBoolType, PsScalarType, PsVectorType, constify
__all__ = ["EliminateConstants"]
class ECContext:
def __init__(self, ctx: KernelCreationContext):
self._ctx = ctx
self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict()
from ..emission import IRAstPrinter
self._printer = IRAstPrinter(indent_width=0, annotate_constants=False)
@property
def extractions(self) -> Iterable[tuple[PsSymbol, PsExpression]]:
return [
(symb, cast(PsExpression, w.n))
for (w, symb) in self._extracted_constants.items()
]
def _get_symb_name(self, expr: PsExpression):
code = self._printer(expr)
code = code.lower()
# remove spaces
code = "".join(code.split())
def valid_char(c):
return (ord("0") <= ord(c) <= ord("9")) or (ord("a") <= ord(c) <= ord("z"))
charmap = {"+": "p", "-": "s", "*": "m", "/": "o"}
charmap = defaultdict(lambda: "_", charmap) # type: ignore
code = "".join((c if valid_char(c) else charmap[c]) for c in code)
return f"__c_{code}"
def extract_expression(self, expr: PsExpression) -> PsSymbolExpr:
dtype = expr.get_dtype()
expr_wrapped = AstEqWrapper(expr)
if expr_wrapped not in self._extracted_constants:
symb_name = self._get_symb_name(expr)
symb = self._ctx.get_new_symbol(symb_name, constify(dtype))
self._extracted_constants[expr_wrapped] = symb
else:
symb = self._extracted_constants[expr_wrapped]
return PsSymbolExpr(symb)
class EliminateConstants:
"""Eliminate constant expressions in various ways.
- Constant folding: Nontrivial constant integer (and optionally floating point) expressions
are evaluated and replaced by their result
- Idempotence elimination: Idempotent operations (e.g. addition of zero, multiplication with one)
are replaced by their result
- Dominance elimination: Multiplication by zero is replaced by zero
- Constant extraction: Optionally, nontrivial constant expressions are extracted and listed at the beginning of
the outermost block.
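
    Example:
        A hedged sketch (``ctx`` assumed; ``expr`` a typified integer expression
        such as ``x + (3 - 3)``)::

            elim_constants = EliminateConstants(ctx)
            expr = elim_constants(expr)  # ``3 - 3`` folds to ``0``; ``x + 0`` reduces to ``x``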
"""
def __init__(
self,
ctx: KernelCreationContext,
extract_constant_exprs: bool = False,
fold_integers: bool = True,
fold_relations: bool = True,
fold_floats: bool = False,
):
self._ctx = ctx
self._typify = Typifier(ctx)
self._fold_integers = fold_integers
self._fold_relations = fold_relations
self._fold_floats = fold_floats
self._extract_constant_exprs = extract_constant_exprs
@overload
def __call__(self, node: PsExpression) -> PsExpression:
pass
@overload
def __call__(self, node: PsBlock) -> PsBlock:
pass
@overload
def __call__(self, node: PsAstNode) -> PsAstNode:
pass
def __call__(self, node: PsAstNode) -> PsAstNode:
ecc = ECContext(self._ctx)
node = self.visit(node, ecc)
if ecc.extractions:
prepend_decls = [
PsDeclaration(PsExpression.make(symb), expr)
for symb, expr in ecc.extractions
]
if not isinstance(node, PsBlock):
node = PsBlock(prepend_decls + [node])
else:
node.children = prepend_decls + list(node.children)
return node
def visit(self, node: PsAstNode, ecc: ECContext) -> PsAstNode:
match node:
case PsExpression():
transformed_expr, _ = self.visit_expr(node, ecc)
return transformed_expr
case _:
node.children = [self.visit(c, ecc) for c in node.children]
return node
def visit_expr(
self, expr: PsExpression, ecc: ECContext
) -> tuple[PsExpression, bool]:
"""Transformation of expressions.
Returns:
            (transformed_expr, is_const): The transformed expression, and a flag indicating whether it is constant
"""
# Return constants and literals as they are
if isinstance(expr, (PsConstantExpr, PsLiteralExpr)):
return expr, True
# Shortcut symbols
if isinstance(expr, PsSymbolExpr):
return expr, False
subtree_results = [
self.visit_expr(cast(PsExpression, c), ecc) for c in expr.children
]
expr.children = [r[0] for r in subtree_results]
subtree_constness = [r[1] for r in subtree_results]
        # Eliminate idempotence, dominance, constant (broad)casts, and trivial relations
match expr:
# Additive idempotence: Addition and subtraction of zero
case PsAdd(PsConstantExpr(c), other_op) if np.all(c.value == 0):
return other_op, all(subtree_constness)
case PsAdd(other_op, PsConstantExpr(c)) if np.all(c.value == 0):
return other_op, all(subtree_constness)
case PsSub(other_op, PsConstantExpr(c)) if np.all(c.value == 0):
return other_op, all(subtree_constness)
# Additive idempotence: Subtraction from zero
case PsSub(PsConstantExpr(c), other_op) if np.all(c.value == 0):
other_transformed, is_const = self.visit_expr(
self._typify(-other_op), ecc
)
return other_transformed, is_const
# Multiplicative idempotence: Multiplication with and division by one
case PsMul(PsConstantExpr(c), other_op) if np.all(c.value == 1):
return other_op, all(subtree_constness)
case PsMul(other_op, PsConstantExpr(c)) if np.all(c.value == 1):
return other_op, all(subtree_constness)
case PsDiv(other_op, PsConstantExpr(c)) | PsIntDiv(
other_op, PsConstantExpr(c)
) if np.all(c.value == 1):
return other_op, all(subtree_constness)
# Trivial remainder at division by one
case PsRem(other_op, PsConstantExpr(c)) if np.all(c.value == 1):
zero = self._typify(PsConstantExpr(PsConstant(0, c.get_dtype())))
return zero, True
# Multiplicative dominance: 0 * x = 0
case PsMul(PsConstantExpr(c), other_op) if np.all(c.value == 0):
return PsConstantExpr(c), True
case PsMul(other_op, PsConstantExpr(c)) if np.all(c.value == 0):
return PsConstantExpr(c), True
# Logical idempotence
case PsAnd(PsConstantExpr(c), other_op) if np.all(c.value):
return other_op, all(subtree_constness)
case PsAnd(other_op, PsConstantExpr(c)) if np.all(c.value):
return other_op, all(subtree_constness)
case PsOr(PsConstantExpr(c), other_op) if not np.any(c.value):
return other_op, all(subtree_constness)
case PsOr(other_op, PsConstantExpr(c)) if not np.any(c.value):
return other_op, all(subtree_constness)
# Logical dominance
case PsAnd(PsConstantExpr(c), other_op) if not np.any(c.value):
return PsConstantExpr(c), True
case PsAnd(other_op, PsConstantExpr(c)) if not np.any(c.value):
return PsConstantExpr(c), True
case PsOr(PsConstantExpr(c), other_op) if np.all(c.value):
return PsConstantExpr(c), True
case PsOr(other_op, PsConstantExpr(c)) if np.all(c.value):
return PsConstantExpr(c), True
# Trivial (broad)casts
case PsCast(target_type, PsConstantExpr(c)):
assert isinstance(target_type, PsNumericType)
return PsConstantExpr(c.reinterpret_as(target_type)), True
case PsCast(target_type, op) if target_type == op.get_dtype():
return op, all(subtree_constness)
case PsVecBroadcast(lanes, PsConstantExpr(c)):
scalar_type = c.get_dtype()
assert isinstance(scalar_type, PsScalarType)
vec_type = PsVectorType(scalar_type, lanes)
return PsConstantExpr(PsConstant(c.value, vec_type)), True
# Trivial comparisons
case (
PsEq(op1, op2) | PsGe(op1, op2) | PsLe(op1, op2)
) if op1.structurally_equal(op2):
arg_dtype = op1.get_dtype()
bool_type = (
PsVectorType(PsBoolType(), arg_dtype.vector_entries)
if isinstance(arg_dtype, PsVectorType)
else PsBoolType()
)
true = self._typify(PsConstantExpr(PsConstant(True, bool_type)))
return true, True
case (
PsNe(op1, op2) | PsGt(op1, op2) | PsLt(op1, op2)
) if op1.structurally_equal(op2):
arg_dtype = op1.get_dtype()
bool_type = (
PsVectorType(PsBoolType(), arg_dtype.vector_entries)
if isinstance(arg_dtype, PsVectorType)
else PsBoolType()
)
false = self._typify(PsConstantExpr(PsConstant(False, bool_type)))
return false, True
# Trivial ternaries
case PsTernary(PsConstantExpr(c), then, els):
if c.value:
return then, subtree_constness[1]
else:
return els, subtree_constness[2]
# end match: no idempotence or dominance encountered
# Detect constant expressions
if all(subtree_constness):
dtype = expr.get_dtype()
is_rel = isinstance(expr, PsRel)
if isinstance(dtype, PsNumericType):
is_int = dtype.is_int()
is_float = dtype.is_float()
is_bool = dtype.is_bool()
else:
is_int = is_float = is_bool = False
do_fold = (
is_bool
or (self._fold_integers and is_int)
or (self._fold_floats and is_float)
or (self._fold_relations and is_rel)
)
folded: PsConstant | None
match expr:
case PsNeg(operand) | PsNot(operand):
if isinstance(operand, PsConstantExpr):
val = operand.constant.value
py_operator = expr.python_operator
if do_fold and py_operator is not None:
assert isinstance(dtype, PsNumericType)
folded = PsConstant(py_operator(val), dtype)
return self._typify(PsConstantExpr(folded)), True
return expr, True
case PsBinOp(op1, op2):
if isinstance(op1, PsConstantExpr) and isinstance(
op2, PsConstantExpr
):
v1 = op1.constant.value
v2 = op2.constant.value
if do_fold:
assert isinstance(dtype, PsNumericType)
py_operator = expr.python_operator
folded = None
if py_operator is not None:
folded = PsConstant(
py_operator(v1, v2),
dtype,
)
elif isinstance(expr, PsDiv):
if is_int:
from ...utils import c_intdiv
folded = PsConstant(c_intdiv(v1, v2), dtype)
elif (
isinstance(dtype, PsNumericType)
and dtype.is_float()
):
folded = PsConstant(v1 / v2, dtype)
if folded is not None:
return self._typify(PsConstantExpr(folded)), True
return expr, True
case PsCall(PsMathFunction(), _):
# TODO: Some math functions (min/max) might be safely folded
return expr, True
# end if: this expression is not constant
# If required, extract constant subexpressions
if self._extract_constant_exprs:
for i, (child, is_const) in enumerate(subtree_results):
if is_const and not isinstance(child, (PsConstantExpr, PsLiteralExpr)):
replacement = ecc.extract_expression(child)
expr.set_child(i, replacement)
# Any other expressions are not considered constant even if their arguments are
return expr, False