Skip to content
Snippets Groups Projects
Commit 2ab005a3 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Add Branch Elimination Pass

parent 9dbacce2
Branches
No related tags found
1 merge request!375Support for Boolean Operations and Relations
Pipeline #64919 passed
...@@ -22,6 +22,7 @@ from ..ast.structural import ( ...@@ -22,6 +22,7 @@ from ..ast.structural import (
PsExpression, PsExpression,
PsAssignment, PsAssignment,
PsDeclaration, PsDeclaration,
PsComment,
) )
from ..ast.expressions import ( from ..ast.expressions import (
PsArrayAccess, PsArrayAccess,
...@@ -326,6 +327,9 @@ class Typifier: ...@@ -326,6 +327,9 @@ class Typifier:
self.visit(body) self.visit(body)
case PsComment():
pass
case _: case _:
raise NotImplementedError(f"Can't typify {node}") raise NotImplementedError(f"Can't typify {node}")
......
from .eliminate_constants import EliminateConstants from .eliminate_constants import EliminateConstants
from .eliminate_branches import EliminateBranches
from .canonicalize_symbols import CanonicalizeSymbols from .canonicalize_symbols import CanonicalizeSymbols
from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations
from .erase_anonymous_structs import EraseAnonymousStructTypes from .erase_anonymous_structs import EraseAnonymousStructTypes
...@@ -7,6 +8,7 @@ from .select_intrinsics import MaterializeVectorIntrinsics ...@@ -7,6 +8,7 @@ from .select_intrinsics import MaterializeVectorIntrinsics
__all__ = [ __all__ = [
"EliminateConstants", "EliminateConstants",
"EliminateBranches",
"CanonicalizeSymbols", "CanonicalizeSymbols",
"HoistLoopInvariantDeclarations", "HoistLoopInvariantDeclarations",
"EraseAnonymousStructTypes", "EraseAnonymousStructTypes",
......
from ..kernelcreation import KernelCreationContext
from ..ast import PsAstNode
from ..ast.structural import PsLoop, PsBlock, PsConditional
from ..ast.expressions import PsConstantExpr
from .eliminate_constants import EliminateConstants
__all__ = ["EliminateBranches"]
class BranchElimContext:
def __init__(self) -> None:
self.enclosing_loops: list[PsLoop] = []
class EliminateBranches:
"""Replace conditional branches by their then- or else-branch if their condition can be unequivocally
evaluated.
This pass will attempt to evaluate branch conditions within their context in the AST, and replace
conditionals by either their then- or their else-block if the branch is unequivocal.
TODO: If islpy is installed, this pass will incorporate information about the iteration regions
of enclosing loops into its analysis.
"""
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False)
def __call__(self, node: PsAstNode) -> PsAstNode:
return self.visit(node, BranchElimContext())
def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode:
match node:
case PsLoop(_, _, _, _, body):
ec.enclosing_loops.append(node)
self.visit(body, ec)
ec.enclosing_loops.pop()
case PsBlock(statements):
statements_new: list[PsAstNode] = []
for stmt in statements:
if isinstance(stmt, PsConditional):
result = self.handle_conditional(stmt, ec)
if result is not None:
statements_new.append(result)
else:
statements_new.append(self.visit(stmt, ec))
node.statements = statements_new
case PsConditional():
result = self.handle_conditional(node, ec)
if result is None:
return PsBlock([])
else:
return result
return node
def handle_conditional(
self, conditional: PsConditional, ec: BranchElimContext
) -> PsConditional | PsBlock | None:
condition_simplified = self._elim_constants(conditional.condition)
match condition_simplified:
case PsConstantExpr(c) if c.value:
return conditional.branch_true
case PsConstantExpr(c) if not c.value:
return conditional.branch_false
# TODO: Analyze condition against counters of enclosing loops using ISL
return conditional
from pystencils import make_slice
from pystencils.backend.kernelcreation import (
KernelCreationContext,
Typifier,
AstFactory,
)
from pystencils.backend.ast.expressions import PsExpression
from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment
from pystencils.backend.constants import PsConstant
from pystencils.backend.transformations import EliminateBranches
from pystencils.types.quick import Int
from pystencils.backend.ast.expressions import PsGt
i0 = PsExpression.make(PsConstant(0, Int(32)))
i1 = PsExpression.make(PsConstant(1, Int(32)))
def test_eliminate_conditional():
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateBranches(ctx)
b1 = PsBlock([PsComment("Branch One")])
b2 = PsBlock([PsComment("Branch Two")])
cond = typify(PsConditional(PsGt(i1, i0), b1, b2))
result = elim(cond)
assert result == b1
cond = typify(PsConditional(PsGt(-i1, i0), b1, b2))
result = elim(cond)
assert result == b2
cond = typify(PsConditional(PsGt(-i1, i0), b1))
result = elim(cond)
assert result.structurally_equal(PsBlock([]))
def test_eliminate_nested_conditional():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
typify = Typifier(ctx)
elim = EliminateBranches(ctx)
b1 = PsBlock([PsComment("Branch One")])
b2 = PsBlock([PsComment("Branch Two")])
cond = typify(PsConditional(PsGt(i1, i0), b1, b2))
ast = factory.loop_nest(("i", "j"), make_slice[:10, :10], PsBlock([cond]))
result = elim(ast)
assert result.body.statements[0].body.statements[0] == b1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment