diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 301540c592e328ea2d185ec0685d81743bc5e488..034675e6209c4eee01efcfe5e35984e7140f1780 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -22,6 +22,7 @@ from ..ast.structural import ( PsExpression, PsAssignment, PsDeclaration, + PsComment, ) from ..ast.expressions import ( PsArrayAccess, @@ -326,6 +327,9 @@ class Typifier: self.visit(body) + case PsComment(): + pass + case _: raise NotImplementedError(f"Can't typify {node}") diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index afb1e4fcd52d2f0fd85a008ca24987f085fd7dc6..01b69509991eaa762a093f50f427f6e4050dc34a 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -1,4 +1,5 @@ from .eliminate_constants import EliminateConstants +from .eliminate_branches import EliminateBranches from .canonicalize_symbols import CanonicalizeSymbols from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .erase_anonymous_structs import EraseAnonymousStructTypes @@ -7,6 +8,7 @@ from .select_intrinsics import MaterializeVectorIntrinsics __all__ = [ "EliminateConstants", + "EliminateBranches", "CanonicalizeSymbols", "HoistLoopInvariantDeclarations", "EraseAnonymousStructTypes", diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py new file mode 100644 index 0000000000000000000000000000000000000000..eab3d3722c30756ab39af072e75e9d6d89874447 --- /dev/null +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -0,0 +1,73 @@ +from ..kernelcreation import KernelCreationContext +from ..ast import PsAstNode +from ..ast.structural import PsLoop, PsBlock, PsConditional +from ..ast.expressions import PsConstantExpr + +from .eliminate_constants import EliminateConstants + +__all__ = ["EliminateBranches"] + + +class BranchElimContext: + def __init__(self) -> None: + self.enclosing_loops: list[PsLoop] = [] + + +class EliminateBranches: + """Replace conditional branches by their then- or else-branch if their condition can be unequivocally + evaluated. + + This pass will attempt to evaluate branch conditions within their context in the AST, and replace + conditionals by either their then- or their else-block if the branch is unequivocal. + + TODO: If islpy is installed, this pass will incorporate information about the iteration regions + of enclosing loops into its analysis. + """ + + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False) + + def __call__(self, node: PsAstNode) -> PsAstNode: + return self.visit(node, BranchElimContext()) + + def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode: + match node: + case PsLoop(_, _, _, _, body): + ec.enclosing_loops.append(node) + self.visit(body, ec) + ec.enclosing_loops.pop() + + case PsBlock(statements): + statements_new: list[PsAstNode] = [] + for stmt in statements: + if isinstance(stmt, PsConditional): + result = self.handle_conditional(stmt, ec) + if result is not None: + statements_new.append(result) + else: + statements_new.append(self.visit(stmt, ec)) + node.statements = statements_new + + case PsConditional(): + result = self.handle_conditional(node, ec) + if result is None: + return PsBlock([]) + else: + return result + + return node + + def handle_conditional( + self, conditional: PsConditional, ec: BranchElimContext + ) -> PsConditional | PsBlock | None: + condition_simplified = self._elim_constants(conditional.condition) + match condition_simplified: + case PsConstantExpr(c) if c.value: + return conditional.branch_true + case PsConstantExpr(c) if not c.value: + return conditional.branch_false + + # TODO: Analyze condition against counters of enclosing loops using ISL + + return conditional diff --git a/tests/nbackend/transformations/test_branch_elimination.py b/tests/nbackend/transformations/test_branch_elimination.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb3526d0b53fd40972c4dfeb06cf3a614bc6c10 --- /dev/null +++ b/tests/nbackend/transformations/test_branch_elimination.py @@ -0,0 +1,55 @@ +from pystencils import make_slice +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + Typifier, + AstFactory, +) +from pystencils.backend.ast.expressions import PsExpression +from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment +from pystencils.backend.constants import PsConstant +from pystencils.backend.transformations import EliminateBranches +from pystencils.types.quick import Int +from pystencils.backend.ast.expressions import PsGt + + +i0 = PsExpression.make(PsConstant(0, Int(32))) +i1 = PsExpression.make(PsConstant(1, Int(32))) + + +def test_eliminate_conditional(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateBranches(ctx) + + b1 = PsBlock([PsComment("Branch One")]) + + b2 = PsBlock([PsComment("Branch Two")]) + + cond = typify(PsConditional(PsGt(i1, i0), b1, b2)) + result = elim(cond) + assert result == b1 + + cond = typify(PsConditional(PsGt(-i1, i0), b1, b2)) + result = elim(cond) + assert result == b2 + + cond = typify(PsConditional(PsGt(-i1, i0), b1)) + result = elim(cond) + assert result.structurally_equal(PsBlock([])) + + +def test_eliminate_nested_conditional(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + typify = Typifier(ctx) + elim = EliminateBranches(ctx) + + b1 = PsBlock([PsComment("Branch One")]) + + b2 = PsBlock([PsComment("Branch Two")]) + + cond = typify(PsConditional(PsGt(i1, i0), b1, b2)) + ast = factory.loop_nest(("i", "j"), make_slice[:10, :10], PsBlock([cond])) + + result = elim(ast) + assert result.body.statements[0].body.statements[0] == b1