diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py index 5c306df0c7361ac868cae17327c1d95e8d20344e..f098d82df1ce6a748097756aa1616a72e57487b5 100644 --- a/src/pystencils/backend/transformations/eliminate_branches.py +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -52,12 +52,12 @@ class EliminateBranches: of enclosing loops and enclosing conditionals into its analysis. Args: - no_isl (bool, optional): disable islpy based analysis + use_isl (bool, optional): enable islpy based analysis (default: True) """ - def __init__(self, ctx: KernelCreationContext, no_isl: bool = False) -> None: + def __init__(self, ctx: KernelCreationContext, use_isl: bool = True) -> None: self._ctx = ctx - self._no_isl = no_isl + self._use_isl = use_isl self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False) def __call__(self, node: PsAstNode) -> PsAstNode: @@ -104,8 +104,8 @@ class EliminateBranches: self, conditional: PsConditional, ec: BranchElimContext ) -> PsConditional | PsBlock | None: condition_simplified = self._elim_constants(conditional.condition) - if not self._no_isl: - condition_simplified = self._isl_symplify_condition( + if self._use_isl: + condition_simplified = self._isl_simplify_condition( condition_simplified, ec ) @@ -117,7 +117,7 @@ class EliminateBranches: return conditional - def _isl_symplify_condition( + def _isl_simplify_condition( self, condition: PsExpression, ec: BranchElimContext ) -> PsExpression: """If installed, use ISL to simplify the passed condition to true or @@ -127,102 +127,101 @@ class EliminateBranches: try: import islpy as isl + except ImportError: + return condition - def printer(expr: PsExpression): - match expr: - case PsSymbolExpr(symbol): - return symbol.name - - case PsConstantExpr(constant): - dtype = constant.get_dtype() - if not isinstance(dtype, (PsIntegerType, PsBoolType)): - raise IslAnalysisError( - "Only scalar integer and bool constant may appear in isl expressions." - ) - return str(constant.value) - - case PsAdd(op1, op2): - return f"({printer(op1)} + {printer(op2)})" - case PsSub(op1, op2): - return f"({printer(op1)} - {printer(op2)})" - case PsMul(op1, op2): - return f"({printer(op1)} * {printer(op2)})" - case PsDiv(op1, op2) | PsIntDiv(op1, op2): - return f"({printer(op1)} / {printer(op2)})" - case PsAnd(op1, op2): - return f"({printer(op1)} and {printer(op2)})" - case PsOr(op1, op2): - return f"({printer(op1)} or {printer(op2)})" - case PsEq(op1, op2): - return f"({printer(op1)} = {printer(op2)})" - case PsNe(op1, op2): - return f"({printer(op1)} != {printer(op2)})" - case PsGt(op1, op2): - return f"({printer(op1)} > {printer(op2)})" - case PsGe(op1, op2): - return f"({printer(op1)} >= {printer(op2)})" - case PsLt(op1, op2): - return f"({printer(op1)} < {printer(op2)})" - case PsLe(op1, op2): - return f"({printer(op1)} <= {printer(op2)})" - - case PsNeg(operand): - return f"(-{printer(operand)})" - case PsNot(operand): - return f"(not {printer(operand)})" - - case PsCast(_, operand): - return printer(operand) + def printer(expr: PsExpression): + match expr: + case PsSymbolExpr(symbol): + return symbol.name - case _: + case PsConstantExpr(constant): + dtype = constant.get_dtype() + if not isinstance(dtype, (PsIntegerType, PsBoolType)): raise IslAnalysisError( - f"Not supported by isl or don't know how to print {expr}" + "Only scalar integer and bool constant may appear in isl expressions." ) - - dofs = collect_undefined_symbols(condition) - outer_conditions = [] - - for loop in ec.enclosing_loops: - if not ( - isinstance(loop.step, PsConstantExpr) - and loop.step.constant.value == 1 - ): + return str(constant.value) + + case PsAdd(op1, op2): + return f"({printer(op1)} + {printer(op2)})" + case PsSub(op1, op2): + return f"({printer(op1)} - {printer(op2)})" + case PsMul(op1, op2): + return f"({printer(op1)} * {printer(op2)})" + case PsDiv(op1, op2) | PsIntDiv(op1, op2): + return f"({printer(op1)} / {printer(op2)})" + case PsAnd(op1, op2): + return f"({printer(op1)} and {printer(op2)})" + case PsOr(op1, op2): + return f"({printer(op1)} or {printer(op2)})" + case PsEq(op1, op2): + return f"({printer(op1)} = {printer(op2)})" + case PsNe(op1, op2): + return f"({printer(op1)} != {printer(op2)})" + case PsGt(op1, op2): + return f"({printer(op1)} > {printer(op2)})" + case PsGe(op1, op2): + return f"({printer(op1)} >= {printer(op2)})" + case PsLt(op1, op2): + return f"({printer(op1)} < {printer(op2)})" + case PsLe(op1, op2): + return f"({printer(op1)} <= {printer(op2)})" + + case PsNeg(operand): + return f"(-{printer(operand)})" + case PsNot(operand): + return f"(not {printer(operand)})" + + case PsCast(_, operand): + return printer(operand) + + case _: raise IslAnalysisError( - "Loops with strides != 1 are not yet supported." + f"Not supported by isl or don't know how to print {expr}" ) - dofs.add(loop.counter.symbol) - dofs.update(collect_undefined_symbols(loop.start)) - dofs.update(collect_undefined_symbols(loop.stop)) + dofs = collect_undefined_symbols(condition) + outer_conditions = [] - loop_start_str = printer(loop.start) - loop_stop_str = printer(loop.stop) - ctr_name = loop.counter.symbol.name - outer_conditions.append( - f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}" + for loop in ec.enclosing_loops: + if not ( + isinstance(loop.step, PsConstantExpr) + and loop.step.constant.value == 1 + ): + raise IslAnalysisError( + "Loops with strides != 1 are not yet supported." ) - for cond in ec.enclosing_conditions: - dofs.update(collect_undefined_symbols(cond)) - outer_conditions.append(printer(cond)) + dofs.add(loop.counter.symbol) + dofs.update(collect_undefined_symbols(loop.start)) + dofs.update(collect_undefined_symbols(loop.stop)) + + loop_start_str = printer(loop.start) + loop_stop_str = printer(loop.stop) + ctr_name = loop.counter.symbol.name + outer_conditions.append( + f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}" + ) - dofs_str = ",".join(dof.name for dof in dofs) - outer_conditions_str = " and ".join(outer_conditions) - condition_str = printer(condition) + for cond in ec.enclosing_conditions: + dofs.update(collect_undefined_symbols(cond)) + outer_conditions.append(printer(cond)) - outer_set = isl.BasicSet(f"{{ [{dofs_str}] : {outer_conditions_str} }}") - inner_set = isl.BasicSet(f"{{ [{dofs_str}] : {condition_str} }}") + dofs_str = ",".join(dof.name for dof in dofs) + outer_conditions_str = " and ".join(outer_conditions) + condition_str = printer(condition) - if inner_set.is_empty(): - return PsExpression.make(PsConstant(False)) + outer_set = isl.BasicSet(f"{{ [{dofs_str}] : {outer_conditions_str} }}") + inner_set = isl.BasicSet(f"{{ [{dofs_str}] : {condition_str} }}") - intersection = outer_set.intersect(inner_set) - if intersection.is_empty(): - return PsExpression.make(PsConstant(False)) - elif intersection == outer_set: - return PsExpression.make(PsConstant(True)) - else: - return condition + if inner_set.is_empty(): + return PsExpression.make(PsConstant(False)) - except ImportError: + intersection = outer_set.intersect(inner_set) + if intersection.is_empty(): + return PsExpression.make(PsConstant(False)) + elif intersection == outer_set: + return PsExpression.make(PsConstant(True)) + else: return condition