Skip to content
Snippets Groups Projects
Commit aebb23ee authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon: Committed by Frederik Hennig
Browse files

Eliminate branches: implement isl analysis and recurse into conditionals

parent 02965644
No related branches found
No related tags found
1 merge request!390Eliminate branches: implement isl analysis and recurse into conditionals
......@@ -16,3 +16,6 @@ ignore_missing_imports=true
[mypy-appdirs.*]
ignore_missing_imports=true
[mypy-islpy.*]
ignore_missing_imports=true
from ..kernelcreation import KernelCreationContext
from ..ast import PsAstNode
from ..ast.analysis import collect_undefined_symbols
from ..ast.structural import PsLoop, PsBlock, PsConditional
from ..ast.expressions import PsConstantExpr
from ..ast.expressions import (
PsAnd,
PsCast,
PsConstant,
PsConstantExpr,
PsDiv,
PsEq,
PsExpression,
PsGe,
PsGt,
PsIntDiv,
PsLe,
PsLt,
PsMul,
PsNe,
PsNeg,
PsNot,
PsOr,
PsSub,
PsSymbolExpr,
PsAdd,
)
from .eliminate_constants import EliminateConstants
from ...types import PsBoolType, PsIntegerType
__all__ = ["EliminateBranches"]
class IslAnalysisError(Exception):
"""Indicates a fatal error during integer set analysis (based on islpy)"""
class BranchElimContext:
def __init__(self) -> None:
self.enclosing_loops: list[PsLoop] = []
self.enclosing_conditions: list[PsExpression] = []
class EliminateBranches:
......@@ -20,12 +48,16 @@ class EliminateBranches:
This pass will attempt to evaluate branch conditions within their context in the AST, and replace
conditionals by either their then- or their else-block if the branch is unequivocal.
TODO: If islpy is installed, this pass will incorporate information about the iteration regions
of enclosing loops into its analysis.
If islpy is installed, this pass will incorporate information about the iteration regions
of enclosing loops and enclosing conditionals into its analysis.
Args:
use_isl (bool, optional): enable islpy based analysis (default: True)
"""
def __init__(self, ctx: KernelCreationContext) -> None:
def __init__(self, ctx: KernelCreationContext, use_isl: bool = True) -> None:
self._ctx = ctx
self._use_isl = use_isl
self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False)
def __call__(self, node: PsAstNode) -> PsAstNode:
......@@ -41,20 +73,30 @@ class EliminateBranches:
case PsBlock(statements):
statements_new: list[PsAstNode] = []
for stmt in statements:
if isinstance(stmt, PsConditional):
result = self.handle_conditional(stmt, ec)
if result is not None:
statements_new.append(result)
else:
statements_new.append(self.visit(stmt, ec))
statements_new.append(self.visit(stmt, ec))
node.statements = statements_new
case PsConditional():
result = self.handle_conditional(node, ec)
if result is None:
return PsBlock([])
else:
return result
match result:
case PsConditional(_, branch_true, branch_false):
ec.enclosing_conditions.append(result.condition)
self.visit(branch_true, ec)
ec.enclosing_conditions.pop()
if branch_false is not None:
ec.enclosing_conditions.append(PsNot(result.condition))
self.visit(branch_false, ec)
ec.enclosing_conditions.pop()
case PsBlock():
self.visit(result, ec)
case None:
result = PsBlock([])
case _:
assert False, "unreachable code"
return result
return node
......@@ -62,12 +104,124 @@ class EliminateBranches:
self, conditional: PsConditional, ec: BranchElimContext
) -> PsConditional | PsBlock | None:
condition_simplified = self._elim_constants(conditional.condition)
if self._use_isl:
condition_simplified = self._isl_simplify_condition(
condition_simplified, ec
)
match condition_simplified:
case PsConstantExpr(c) if c.value:
return conditional.branch_true
case PsConstantExpr(c) if not c.value:
return conditional.branch_false
# TODO: Analyze condition against counters of enclosing loops using ISL
return conditional
def _isl_simplify_condition(
self, condition: PsExpression, ec: BranchElimContext
) -> PsExpression:
"""If installed, use ISL to simplify the passed condition to true or
false based on enclosing loops and conditionals. If no simplification
can be made or ISL is not installed, the original condition is returned.
"""
try:
import islpy as isl
except ImportError:
return condition
def printer(expr: PsExpression):
match expr:
case PsSymbolExpr(symbol):
return symbol.name
case PsConstantExpr(constant):
dtype = constant.get_dtype()
if not isinstance(dtype, (PsIntegerType, PsBoolType)):
raise IslAnalysisError(
"Only scalar integer and bool constant may appear in isl expressions."
)
return str(constant.value)
case PsAdd(op1, op2):
return f"({printer(op1)} + {printer(op2)})"
case PsSub(op1, op2):
return f"({printer(op1)} - {printer(op2)})"
case PsMul(op1, op2):
return f"({printer(op1)} * {printer(op2)})"
case PsDiv(op1, op2) | PsIntDiv(op1, op2):
return f"({printer(op1)} / {printer(op2)})"
case PsAnd(op1, op2):
return f"({printer(op1)} and {printer(op2)})"
case PsOr(op1, op2):
return f"({printer(op1)} or {printer(op2)})"
case PsEq(op1, op2):
return f"({printer(op1)} = {printer(op2)})"
case PsNe(op1, op2):
return f"({printer(op1)} != {printer(op2)})"
case PsGt(op1, op2):
return f"({printer(op1)} > {printer(op2)})"
case PsGe(op1, op2):
return f"({printer(op1)} >= {printer(op2)})"
case PsLt(op1, op2):
return f"({printer(op1)} < {printer(op2)})"
case PsLe(op1, op2):
return f"({printer(op1)} <= {printer(op2)})"
case PsNeg(operand):
return f"(-{printer(operand)})"
case PsNot(operand):
return f"(not {printer(operand)})"
case PsCast(_, operand):
return printer(operand)
case _:
raise IslAnalysisError(
f"Not supported by isl or don't know how to print {expr}"
)
dofs = collect_undefined_symbols(condition)
outer_conditions = []
for loop in ec.enclosing_loops:
if not (
isinstance(loop.step, PsConstantExpr)
and loop.step.constant.value == 1
):
raise IslAnalysisError(
"Loops with strides != 1 are not yet supported."
)
dofs.add(loop.counter.symbol)
dofs.update(collect_undefined_symbols(loop.start))
dofs.update(collect_undefined_symbols(loop.stop))
loop_start_str = printer(loop.start)
loop_stop_str = printer(loop.stop)
ctr_name = loop.counter.symbol.name
outer_conditions.append(
f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"
)
for cond in ec.enclosing_conditions:
dofs.update(collect_undefined_symbols(cond))
outer_conditions.append(printer(cond))
dofs_str = ",".join(dof.name for dof in dofs)
outer_conditions_str = " and ".join(outer_conditions)
condition_str = printer(condition)
outer_set = isl.BasicSet(f"{{ [{dofs_str}] : {outer_conditions_str} }}")
inner_set = isl.BasicSet(f"{{ [{dofs_str}] : {condition_str} }}")
if inner_set.is_empty():
return PsExpression.make(PsConstant(False))
intersection = outer_set.intersect(inner_set)
if intersection.is_empty():
return PsExpression.make(PsConstant(False))
elif intersection == outer_set:
return PsExpression.make(PsConstant(True))
else:
return condition
......@@ -4,12 +4,18 @@ from pystencils.backend.kernelcreation import (
Typifier,
AstFactory,
)
from pystencils.backend.ast.expressions import PsExpression
from pystencils.backend.ast.expressions import (
PsExpression,
PsEq,
PsGe,
PsGt,
PsLe,
PsLt,
)
from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment
from pystencils.backend.constants import PsConstant
from pystencils.backend.transformations import EliminateBranches
from pystencils.types.quick import Int
from pystencils.backend.ast.expressions import PsGt
i0 = PsExpression.make(PsConstant(0, Int(32)))
......@@ -53,3 +59,39 @@ def test_eliminate_nested_conditional():
result = elim(ast)
assert result.body.statements[0].body.statements[0] == b1
def test_isl():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
typify = Typifier(ctx)
elim = EliminateBranches(ctx)
i = PsExpression.make(ctx.get_symbol("i", ctx.index_dtype))
j = PsExpression.make(ctx.get_symbol("j", ctx.index_dtype))
const_2 = PsExpression.make(PsConstant(2, ctx.index_dtype))
const_4 = PsExpression.make(PsConstant(4, ctx.index_dtype))
a_true = PsBlock([PsComment("a true")])
a_false = PsBlock([PsComment("a false")])
b_true = PsBlock([PsComment("b true")])
b_false = PsBlock([PsComment("b false")])
c_true = PsBlock([PsComment("c true")])
c_false = PsBlock([PsComment("c false")])
a = PsConditional(PsLt(i + j, const_2 * const_4), a_true, a_false)
b = PsConditional(PsGe(j, const_4), b_true, b_false)
c = PsConditional(PsEq(i, const_4), c_true, c_false)
outer_loop = factory.loop(j.symbol.name, slice(0, 3), PsBlock([a, b, c]))
outer_cond = typify(
PsConditional(PsLe(i, const_4), PsBlock([outer_loop]), PsBlock([]))
)
ast = outer_cond
result = elim(ast)
assert result.branch_true.statements[0].body.statements[0] == a_true
assert result.branch_true.statements[0].body.statements[1] == b_false
assert result.branch_true.statements[0].body.statements[2] == c
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment