Skip to content
Snippets Groups Projects
Commit 19d0df07 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

constant elimination, part 2: subexpression extraction

parent d9a260ef
No related branches found
No related tags found
No related merge requests found
Pipeline #63824 failed
......@@ -14,6 +14,14 @@ class PsBlock(PsAstNode):
def __init__(self, cs: Sequence[PsAstNode]):
self._statements = list(cs)
@property
def children(self) -> Sequence[PsAstNode]:
return self.get_children()
@children.setter
def children(self, cs: Sequence[PsAstNode]):
self._statements = list(cs)
def get_children(self) -> tuple[PsAstNode, ...]:
return tuple(self._statements)
......
......@@ -11,8 +11,9 @@ def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any:
return obj
class EqWrapper:
"""Wrapper around AST nodes that maps the `__eq__` method onto `structurally_equal`.
class AstEqWrapper:
"""Wrapper around AST nodes that computes a hash from the AST's textual representation
and maps the `__eq__` method onto `structurally_equal`.
Useful in dictionaries when the goal is to keep track of subtrees according to their
structure, e.g. in elimination of constants or common subexpressions.
......@@ -26,10 +27,12 @@ class EqWrapper:
return self._node
def __eq__(self, other: object) -> bool:
if not isinstance(other, EqWrapper):
if not isinstance(other, AstEqWrapper):
return False
return self._node.structurally_equal(other._node)
def __hash__(self) -> int:
return hash(self._node)
# TODO: consider replacing this with smth. more performant
# TODO: Check that repr is implemented by all AST nodes
return hash(repr(self._node))
......@@ -150,9 +150,10 @@ class FullIterationSpace(IterationSpace):
if isinstance(expr, int):
return PsConstantExpr(PsConstant(expr, ctx.index_dtype))
elif isinstance(expr, sp.Expr):
return typifier.typify_expression(
typed_expr, _ = typifier.typify_expression(
freeze.freeze_expression(expr), ctx.index_dtype
)
return typed_expr
else:
raise ValueError(f"Invalid entry in slice: {expr}")
......
......@@ -128,10 +128,14 @@ class Typifier:
def typify_expression(
self, expr: PsExpression, target_type: PsNumericType | None = None
) -> PsExpression:
) -> tuple[PsExpression, PsType]:
tc = TypeContext(target_type)
self.visit_expr(expr, tc)
return expr
if tc.target_type is None:
raise TypificationError(f"Unable to determine type for {expr}")
return expr, tc.target_type
def visit(self, node: PsAstNode) -> None:
"""Recursive processing of structural nodes"""
......
from typing import cast
from typing import cast, Iterable
from collections import defaultdict
from ..kernelcreation.context import KernelCreationContext
from ..kernelcreation import KernelCreationContext, Typifier
from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsDeclaration
from ..ast.expressions import (
PsExpression,
PsConstantExpr,
......@@ -13,17 +15,63 @@ from ..ast.expressions import (
PsMul,
PsDiv,
)
from ..ast.util import AstEqWrapper
from ..constants import PsConstant
from ...types import PsIntegerType, PsIeeeFloatType
from ..symbols import PsSymbol
from ...types import PsIntegerType, PsIeeeFloatType, PsTypeError
from ..emission import CAstPrinter
__all__ = ["EliminateConstants"]
class ECContext:
def __init__(self):
pass
def __init__(self, ctx: KernelCreationContext):
self._ctx = ctx
self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict()
self._typifier = Typifier(ctx)
self._printer = CAstPrinter(0)
@property
def extractions(self) -> Iterable[tuple[PsSymbol, PsExpression]]:
return [
(symb, cast(PsExpression, w.n))
for (w, symb) in self._extracted_constants.items()
]
def _get_symb_name(self, expr: PsExpression):
code = self._printer(expr)
code = code.lower()
# remove spaces
code = "".join(code.split())
def valid_char(c):
return (ord("0") <= ord(c) <= ord("9")) or (ord("a") <= ord(c) <= ord("z"))
charmap = {"+": "p", "-": "s", "*": "m", "/": "o"}
charmap = defaultdict(lambda: "_", charmap) # type: ignore
code = "".join((c if valid_char(c) else charmap[c]) for c in code)
return f"__c_{code}"
def extract_expression(self, expr: PsExpression) -> PsSymbolExpr:
expr, dtype = self._typifier.typify_expression(expr)
expr_wrapped = AstEqWrapper(expr)
if expr_wrapped not in self._extracted_constants:
symb_name = self._get_symb_name(expr)
try:
symb = self._ctx.get_symbol(symb_name, dtype)
except PsTypeError:
symb = self._ctx.get_symbol(f"{symb_name}_{dtype.c_string()}", dtype)
self._extracted_constants[expr_wrapped] = symb
else:
symb = self._extracted_constants[expr_wrapped]
return PsSymbolExpr(symb)
class EliminateConstants:
......@@ -38,26 +86,45 @@ class EliminateConstants:
the outermost block.
"""
def __init__(self, ctx: KernelCreationContext):
def __init__(
self, ctx: KernelCreationContext, extract_constant_exprs: bool = False
):
self._ctx = ctx
self._fold_integers = True
self._fold_floats = False
self._extract_constant_exprs = True
self._extract_constant_exprs = extract_constant_exprs
def __call__(self, node: PsAstNode) -> PsAstNode:
return self.visit(node)
ecc = ECContext(self._ctx)
def visit(self, node: PsAstNode) -> PsAstNode:
node = self.visit(node, ecc)
if ecc.extractions:
prepend_decls = [
PsDeclaration(PsExpression.make(symb), expr)
for symb, expr in ecc.extractions
]
if not isinstance(node, PsBlock):
node = PsBlock(prepend_decls + [node])
else:
node.children = prepend_decls + list(node.children)
return node
def visit(self, node: PsAstNode, ecc: ECContext) -> PsAstNode:
match node:
case PsExpression():
transformed_expr, _ = self.visit_expr(node)
transformed_expr, _ = self.visit_expr(node, ecc)
return transformed_expr
case _:
node.children = [self.visit(c) for c in node.children]
node.children = [self.visit(c, ecc) for c in node.children]
return node
def visit_expr(self, expr: PsExpression) -> tuple[PsExpression, bool]:
def visit_expr(
self, expr: PsExpression, ecc: ECContext
) -> tuple[PsExpression, bool]:
"""Transformation of expressions.
Returns:
......@@ -66,13 +133,13 @@ class EliminateConstants:
# Return constants as they are
if isinstance(expr, PsConstantExpr):
return expr, True
# Shortcut symbols
if isinstance(expr, PsSymbolExpr):
return expr, False
subtree_results = [
self.visit_expr(cast(PsExpression, c)) for c in expr.children
self.visit_expr(cast(PsExpression, c), ecc) for c in expr.children
]
expr.children = [r[0] for r in subtree_results]
subtree_constness = [r[1] for r in subtree_results]
......@@ -91,7 +158,7 @@ class EliminateConstants:
# Additive idempotence: Subtraction from zero
case PsSub(PsConstantExpr(c), other_op) if c.value == 0:
other_transformed, is_const = self.visit_expr(-other_op)
other_transformed, is_const = self.visit_expr(-other_op, ecc)
return other_transformed, is_const
# Multiplicative idempotence: Multiplication with and division by one
......@@ -155,7 +222,14 @@ class EliminateConstants:
expr.operand1 = op1_transformed
expr.operand2 = op2_transformed
return expr, True
# end if: no constant expressions encountered
# end if: this expression is not constant
# If required, extract constant subexpressions
if self._extract_constant_exprs:
for i, (child, is_const) in enumerate(subtree_results):
if is_const and not isinstance(child, PsConstantExpr):
replacement = ecc.extract_expression(child)
expr.set_child(i, replacement)
# Any other expressions are not considered constant even if their arguments are
return expr, False
......@@ -82,8 +82,11 @@ def create_kernel(
kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
# Simplifying transformations
kernel_ast = cast(PsBlock, EliminateConstants(ctx)(kernel_ast))
kernel_ast = cast(PsBlock, EraseAnonymousStructTypes(ctx)(kernel_ast))
elim_constants = EliminateConstants(ctx, extract_constant_exprs=True)
kernel_ast = cast(PsBlock, elim_constants(kernel_ast))
erase_anons = EraseAnonymousStructTypes(ctx)
kernel_ast = cast(PsBlock, erase_anons(kernel_ast))
# 7. Apply optimizations
# - Vectorization
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment