From 2ab005a304fb3ed372001ce7bc2606ede7b7892a Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 4 Apr 2024 16:58:17 +0200 Subject: [PATCH] Add Branch Elimination Pass --- .../backend/kernelcreation/typification.py | 4 + .../backend/transformations/__init__.py | 2 + .../transformations/eliminate_branches.py | 73 +++++++++++++++++++ .../test_branch_elimination.py | 55 ++++++++++++++ 4 files changed, 134 insertions(+) create mode 100644 src/pystencils/backend/transformations/eliminate_branches.py create mode 100644 tests/nbackend/transformations/test_branch_elimination.py diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 301540c59..034675e62 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -22,6 +22,7 @@ from ..ast.structural import ( PsExpression, PsAssignment, PsDeclaration, + PsComment, ) from ..ast.expressions import ( PsArrayAccess, @@ -326,6 +327,9 @@ class Typifier: self.visit(body) + case PsComment(): + pass + case _: raise NotImplementedError(f"Can't typify {node}") diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index afb1e4fcd..01b695099 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -1,4 +1,5 @@ from .eliminate_constants import EliminateConstants +from .eliminate_branches import EliminateBranches from .canonicalize_symbols import CanonicalizeSymbols from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .erase_anonymous_structs import EraseAnonymousStructTypes @@ -7,6 +8,7 @@ from .select_intrinsics import MaterializeVectorIntrinsics __all__ = [ "EliminateConstants", + "EliminateBranches", "CanonicalizeSymbols", "HoistLoopInvariantDeclarations", "EraseAnonymousStructTypes", diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py new file mode 100644 index 000000000..eab3d3722 --- /dev/null +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -0,0 +1,73 @@ +from ..kernelcreation import KernelCreationContext +from ..ast import PsAstNode +from ..ast.structural import PsLoop, PsBlock, PsConditional +from ..ast.expressions import PsConstantExpr + +from .eliminate_constants import EliminateConstants + +__all__ = ["EliminateBranches"] + + +class BranchElimContext: + def __init__(self) -> None: + self.enclosing_loops: list[PsLoop] = [] + + +class EliminateBranches: + """Replace conditional branches by their then- or else-branch if their condition can be unequivocally + evaluated. + + This pass will attempt to evaluate branch conditions within their context in the AST, and replace + conditionals by either their then- or their else-block if the branch is unequivocal. + + TODO: If islpy is installed, this pass will incorporate information about the iteration regions + of enclosing loops into its analysis. + """ + + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False) + + def __call__(self, node: PsAstNode) -> PsAstNode: + return self.visit(node, BranchElimContext()) + + def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode: + match node: + case PsLoop(_, _, _, _, body): + ec.enclosing_loops.append(node) + self.visit(body, ec) + ec.enclosing_loops.pop() + + case PsBlock(statements): + statements_new: list[PsAstNode] = [] + for stmt in statements: + if isinstance(stmt, PsConditional): + result = self.handle_conditional(stmt, ec) + if result is not None: + statements_new.append(result) + else: + statements_new.append(self.visit(stmt, ec)) + node.statements = statements_new + + case PsConditional(): + result = self.handle_conditional(node, ec) + if result is None: + return PsBlock([]) + else: + return result + + return node + + def handle_conditional( + self, conditional: PsConditional, ec: BranchElimContext + ) -> PsConditional | PsBlock | None: + condition_simplified = self._elim_constants(conditional.condition) + match condition_simplified: + case PsConstantExpr(c) if c.value: + return conditional.branch_true + case PsConstantExpr(c) if not c.value: + return conditional.branch_false + + # TODO: Analyze condition against counters of enclosing loops using ISL + + return conditional diff --git a/tests/nbackend/transformations/test_branch_elimination.py b/tests/nbackend/transformations/test_branch_elimination.py new file mode 100644 index 000000000..0fb3526d0 --- /dev/null +++ b/tests/nbackend/transformations/test_branch_elimination.py @@ -0,0 +1,55 @@ +from pystencils import make_slice +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + Typifier, + AstFactory, +) +from pystencils.backend.ast.expressions import PsExpression +from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment +from pystencils.backend.constants import PsConstant +from pystencils.backend.transformations import EliminateBranches +from pystencils.types.quick import Int +from pystencils.backend.ast.expressions import PsGt + + +i0 = PsExpression.make(PsConstant(0, Int(32))) +i1 = PsExpression.make(PsConstant(1, Int(32))) + + +def test_eliminate_conditional(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateBranches(ctx) + + b1 = PsBlock([PsComment("Branch One")]) + + b2 = PsBlock([PsComment("Branch Two")]) + + cond = typify(PsConditional(PsGt(i1, i0), b1, b2)) + result = elim(cond) + assert result == b1 + + cond = typify(PsConditional(PsGt(-i1, i0), b1, b2)) + result = elim(cond) + assert result == b2 + + cond = typify(PsConditional(PsGt(-i1, i0), b1)) + result = elim(cond) + assert result.structurally_equal(PsBlock([])) + + +def test_eliminate_nested_conditional(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + typify = Typifier(ctx) + elim = EliminateBranches(ctx) + + b1 = PsBlock([PsComment("Branch One")]) + + b2 = PsBlock([PsComment("Branch Two")]) + + cond = typify(PsConditional(PsGt(i1, i0), b1, b2)) + ast = factory.loop_nest(("i", "j"), make_slice[:10, :10], PsBlock([cond])) + + result = elim(ast) + assert result.body.statements[0].body.statements[0] == b1 -- GitLab