Skip to content
Snippets Groups Projects

Eliminate branches: implement isl analysis and recurse into conditionals

Merged Daniel Bauer requested to merge hyteg/pystencils:bauerd/isl into backend-rework
All threads resolved!
@@ -52,12 +52,12 @@ class EliminateBranches:
of enclosing loops and enclosing conditionals into its analysis.
Args:
no_isl (bool, optional): disable islpy based analysis
use_isl (bool, optional): enable islpy based analysis (default: True)
"""
def __init__(self, ctx: KernelCreationContext, no_isl: bool = False) -> None:
def __init__(self, ctx: KernelCreationContext, use_isl: bool = True) -> None:
self._ctx = ctx
self._no_isl = no_isl
self._use_isl = use_isl
self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False)
def __call__(self, node: PsAstNode) -> PsAstNode:
@@ -73,12 +73,7 @@ class EliminateBranches:
case PsBlock(statements):
statements_new: list[PsAstNode] = []
for stmt in statements:
if isinstance(stmt, PsConditional):
simplified = self.visit(stmt, ec)
if not simplified.structurally_equal(PsBlock([])):
statements_new.append(simplified)
else:
statements_new.append(self.visit(stmt, ec))
statements_new.append(self.visit(stmt, ec))
node.statements = statements_new
case PsConditional():
@@ -109,8 +104,8 @@ class EliminateBranches:
self, conditional: PsConditional, ec: BranchElimContext
) -> PsConditional | PsBlock | None:
condition_simplified = self._elim_constants(conditional.condition)
if not self._no_isl:
condition_simplified = self._isl_symplify_condition(
if self._use_isl:
condition_simplified = self._isl_simplify_condition(
condition_simplified, ec
)
@@ -122,7 +117,7 @@ class EliminateBranches:
return conditional
def _isl_symplify_condition(
def _isl_simplify_condition(
self, condition: PsExpression, ec: BranchElimContext
) -> PsExpression:
"""If installed, use ISL to simplify the passed condition to true or
@@ -132,102 +127,101 @@ class EliminateBranches:
try:
import islpy as isl
except ImportError:
return condition
def printer(expr: PsExpression):
match expr:
case PsSymbolExpr(symbol):
return symbol.name
case PsConstantExpr(constant):
dtype = constant.get_dtype()
if not isinstance(dtype, (PsIntegerType, PsBoolType)):
raise IslAnalysisError(
"Only scalar integer and bool constant may appear in isl expressions."
)
return str(constant.value)
case PsAdd(op1, op2):
return f"({printer(op1)} + {printer(op2)})"
case PsSub(op1, op2):
return f"({printer(op1)} - {printer(op2)})"
case PsMul(op1, op2):
return f"({printer(op1)} * {printer(op2)})"
case PsDiv(op1, op2) | PsIntDiv(op1, op2):
return f"({printer(op1)} / {printer(op2)})"
case PsAnd(op1, op2):
return f"({printer(op1)} and {printer(op2)})"
case PsOr(op1, op2):
return f"({printer(op1)} or {printer(op2)})"
case PsEq(op1, op2):
return f"({printer(op1)} = {printer(op2)})"
case PsNe(op1, op2):
return f"({printer(op1)} != {printer(op2)})"
case PsGt(op1, op2):
return f"({printer(op1)} > {printer(op2)})"
case PsGe(op1, op2):
return f"({printer(op1)} >= {printer(op2)})"
case PsLt(op1, op2):
return f"({printer(op1)} < {printer(op2)})"
case PsLe(op1, op2):
return f"({printer(op1)} <= {printer(op2)})"
case PsNeg(operand):
return f"(-{printer(operand)})"
case PsNot(operand):
return f"(not {printer(operand)})"
case PsCast(_, operand):
return printer(operand)
def printer(expr: PsExpression):
match expr:
case PsSymbolExpr(symbol):
return symbol.name
case _:
case PsConstantExpr(constant):
dtype = constant.get_dtype()
if not isinstance(dtype, (PsIntegerType, PsBoolType)):
raise IslAnalysisError(
f"Not supported by isl or don't know how to print {expr}"
"Only scalar integer and bool constant may appear in isl expressions."
)
dofs = collect_undefined_symbols(condition)
outer_conditions = []
for loop in ec.enclosing_loops:
if not (
isinstance(loop.step, PsConstantExpr)
and loop.step.constant.value == 1
):
return str(constant.value)
case PsAdd(op1, op2):
return f"({printer(op1)} + {printer(op2)})"
case PsSub(op1, op2):
return f"({printer(op1)} - {printer(op2)})"
case PsMul(op1, op2):
return f"({printer(op1)} * {printer(op2)})"
case PsDiv(op1, op2) | PsIntDiv(op1, op2):
return f"({printer(op1)} / {printer(op2)})"
case PsAnd(op1, op2):
return f"({printer(op1)} and {printer(op2)})"
case PsOr(op1, op2):
return f"({printer(op1)} or {printer(op2)})"
case PsEq(op1, op2):
return f"({printer(op1)} = {printer(op2)})"
case PsNe(op1, op2):
return f"({printer(op1)} != {printer(op2)})"
case PsGt(op1, op2):
return f"({printer(op1)} > {printer(op2)})"
case PsGe(op1, op2):
return f"({printer(op1)} >= {printer(op2)})"
case PsLt(op1, op2):
return f"({printer(op1)} < {printer(op2)})"
case PsLe(op1, op2):
return f"({printer(op1)} <= {printer(op2)})"
case PsNeg(operand):
return f"(-{printer(operand)})"
case PsNot(operand):
return f"(not {printer(operand)})"
case PsCast(_, operand):
return printer(operand)
case _:
raise IslAnalysisError(
"Loops with strides != 1 are not yet supported."
f"Not supported by isl or don't know how to print {expr}"
)
dofs.add(loop.counter.symbol)
dofs.update(collect_undefined_symbols(loop.start))
dofs.update(collect_undefined_symbols(loop.stop))
dofs = collect_undefined_symbols(condition)
outer_conditions = []
loop_start_str = printer(loop.start)
loop_stop_str = printer(loop.stop)
ctr_name = loop.counter.symbol.name
outer_conditions.append(
f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"
for loop in ec.enclosing_loops:
if not (
isinstance(loop.step, PsConstantExpr)
and loop.step.constant.value == 1
):
raise IslAnalysisError(
"Loops with strides != 1 are not yet supported."
)
for cond in ec.enclosing_conditions:
dofs.update(collect_undefined_symbols(cond))
outer_conditions.append(printer(cond))
dofs.add(loop.counter.symbol)
dofs.update(collect_undefined_symbols(loop.start))
dofs.update(collect_undefined_symbols(loop.stop))
loop_start_str = printer(loop.start)
loop_stop_str = printer(loop.stop)
ctr_name = loop.counter.symbol.name
outer_conditions.append(
f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"
)
dofs_str = ",".join(dof.name for dof in dofs)
outer_conditions_str = " and ".join(outer_conditions)
condition_str = printer(condition)
for cond in ec.enclosing_conditions:
dofs.update(collect_undefined_symbols(cond))
outer_conditions.append(printer(cond))
outer_set = isl.BasicSet(f"{{ [{dofs_str}] : {outer_conditions_str} }}")
inner_set = isl.BasicSet(f"{{ [{dofs_str}] : {condition_str} }}")
dofs_str = ",".join(dof.name for dof in dofs)
outer_conditions_str = " and ".join(outer_conditions)
condition_str = printer(condition)
if inner_set.is_empty():
return PsExpression.make(PsConstant(False))
outer_set = isl.BasicSet(f"{{ [{dofs_str}] : {outer_conditions_str} }}")
inner_set = isl.BasicSet(f"{{ [{dofs_str}] : {condition_str} }}")
intersection = outer_set.intersect(inner_set)
if intersection.is_empty():
return PsExpression.make(PsConstant(False))
elif intersection == outer_set:
return PsExpression.make(PsConstant(True))
else:
return condition
if inner_set.is_empty():
return PsExpression.make(PsConstant(False))
except ImportError:
intersection = outer_set.intersect(inner_set)
if intersection.is_empty():
return PsExpression.make(PsConstant(False))
elif intersection == outer_set:
return PsExpression.make(PsConstant(True))
else:
return condition