From 2ab005a304fb3ed372001ce7bc2606ede7b7892a Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 4 Apr 2024 16:58:17 +0200
Subject: [PATCH] Add Branch Elimination Pass

---
 .../backend/kernelcreation/typification.py    |  4 +
 .../backend/transformations/__init__.py       |  2 +
 .../transformations/eliminate_branches.py     | 73 +++++++++++++++++++
 .../test_branch_elimination.py                | 55 ++++++++++++++
 4 files changed, 134 insertions(+)
 create mode 100644 src/pystencils/backend/transformations/eliminate_branches.py
 create mode 100644 tests/nbackend/transformations/test_branch_elimination.py

diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 301540c59..034675e62 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -22,6 +22,7 @@ from ..ast.structural import (
     PsExpression,
     PsAssignment,
     PsDeclaration,
+    PsComment,
 )
 from ..ast.expressions import (
     PsArrayAccess,
@@ -326,6 +327,9 @@ class Typifier:
 
                 self.visit(body)
 
+            case PsComment():
+                pass
+
             case _:
                 raise NotImplementedError(f"Can't typify {node}")
 
diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py
index afb1e4fcd..01b695099 100644
--- a/src/pystencils/backend/transformations/__init__.py
+++ b/src/pystencils/backend/transformations/__init__.py
@@ -1,4 +1,5 @@
 from .eliminate_constants import EliminateConstants
+from .eliminate_branches import EliminateBranches
 from .canonicalize_symbols import CanonicalizeSymbols
 from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations
 from .erase_anonymous_structs import EraseAnonymousStructTypes
@@ -7,6 +8,7 @@ from .select_intrinsics import MaterializeVectorIntrinsics
 
 __all__ = [
     "EliminateConstants",
+    "EliminateBranches",
     "CanonicalizeSymbols",
     "HoistLoopInvariantDeclarations",
     "EraseAnonymousStructTypes",
diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py
new file mode 100644
index 000000000..eab3d3722
--- /dev/null
+++ b/src/pystencils/backend/transformations/eliminate_branches.py
@@ -0,0 +1,73 @@
+from ..kernelcreation import KernelCreationContext
+from ..ast import PsAstNode
+from ..ast.structural import PsLoop, PsBlock, PsConditional
+from ..ast.expressions import PsConstantExpr
+
+from .eliminate_constants import EliminateConstants
+
+__all__ = ["EliminateBranches"]
+
+
+class BranchElimContext:
+    def __init__(self) -> None:
+        self.enclosing_loops: list[PsLoop] = []
+
+
+class EliminateBranches:
+    """Replace conditional branches by their then- or else-branch if their condition can be unequivocally
+    evaluated.
+
+    This pass will attempt to evaluate branch conditions within their context in the AST, and replace
+    conditionals by either their then- or their else-block if the branch is unequivocal.
+
+    TODO: If islpy is installed, this pass will incorporate information about the iteration regions
+    of enclosing loops into its analysis.
+    """
+
+    def __init__(self, ctx: KernelCreationContext) -> None:
+        self._ctx = ctx
+        self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False)
+
+    def __call__(self, node: PsAstNode) -> PsAstNode:
+        return self.visit(node, BranchElimContext())
+
+    def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode:
+        match node:
+            case PsLoop(_, _, _, _, body):
+                ec.enclosing_loops.append(node)
+                self.visit(body, ec)
+                ec.enclosing_loops.pop()
+
+            case PsBlock(statements):
+                statements_new: list[PsAstNode] = []
+                for stmt in statements:
+                    if isinstance(stmt, PsConditional):
+                        result = self.handle_conditional(stmt, ec)
+                        if result is not None:
+                            statements_new.append(result)
+                    else:
+                        statements_new.append(self.visit(stmt, ec))
+                node.statements = statements_new
+
+            case PsConditional():
+                result = self.handle_conditional(node, ec)
+                if result is None:
+                    return PsBlock([])
+                else:
+                    return result
+
+        return node
+
+    def handle_conditional(
+        self, conditional: PsConditional, ec: BranchElimContext
+    ) -> PsConditional | PsBlock | None:
+        condition_simplified = self._elim_constants(conditional.condition)
+        match condition_simplified:
+            case PsConstantExpr(c) if c.value:
+                return conditional.branch_true
+            case PsConstantExpr(c) if not c.value:
+                return conditional.branch_false
+
+        #   TODO: Analyze condition against counters of enclosing loops using ISL
+
+        return conditional
diff --git a/tests/nbackend/transformations/test_branch_elimination.py b/tests/nbackend/transformations/test_branch_elimination.py
new file mode 100644
index 000000000..0fb3526d0
--- /dev/null
+++ b/tests/nbackend/transformations/test_branch_elimination.py
@@ -0,0 +1,55 @@
+from pystencils import make_slice
+from pystencils.backend.kernelcreation import (
+    KernelCreationContext,
+    Typifier,
+    AstFactory,
+)
+from pystencils.backend.ast.expressions import PsExpression
+from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment
+from pystencils.backend.constants import PsConstant
+from pystencils.backend.transformations import EliminateBranches
+from pystencils.types.quick import Int
+from pystencils.backend.ast.expressions import PsGt
+
+
+i0 = PsExpression.make(PsConstant(0, Int(32)))
+i1 = PsExpression.make(PsConstant(1, Int(32)))
+
+
+def test_eliminate_conditional():
+    ctx = KernelCreationContext()
+    typify = Typifier(ctx)
+    elim = EliminateBranches(ctx)
+
+    b1 = PsBlock([PsComment("Branch One")])
+
+    b2 = PsBlock([PsComment("Branch Two")])
+
+    cond = typify(PsConditional(PsGt(i1, i0), b1, b2))
+    result = elim(cond)
+    assert result == b1
+
+    cond = typify(PsConditional(PsGt(-i1, i0), b1, b2))
+    result = elim(cond)
+    assert result == b2
+
+    cond = typify(PsConditional(PsGt(-i1, i0), b1))
+    result = elim(cond)
+    assert result.structurally_equal(PsBlock([]))
+
+
+def test_eliminate_nested_conditional():
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+    typify = Typifier(ctx)
+    elim = EliminateBranches(ctx)
+
+    b1 = PsBlock([PsComment("Branch One")])
+
+    b2 = PsBlock([PsComment("Branch Two")])
+
+    cond = typify(PsConditional(PsGt(i1, i0), b1, b2))
+    ast = factory.loop_nest(("i", "j"), make_slice[:10, :10], PsBlock([cond]))
+
+    result = elim(ast)
+    assert result.body.statements[0].body.statements[0] == b1
-- 
GitLab