Skip to content
Snippets Groups Projects
Commit 9dbacce2 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Add elimination of trivial comparisons

parent a6ec04ed
No related branches found
No related tags found
1 merge request!375Support for Boolean Operations and Relations
Pipeline #64918 passed
......@@ -431,7 +431,7 @@ class PsUnOp(PsExpression):
@property
def python_operator(self) -> None | Callable[[Any], Any]:
return None
def __repr__(self) -> str:
opname = self.__class__.__name__
return f"{opname}({repr(self._operand)})"
......
......@@ -19,7 +19,13 @@ from ..ast.expressions import (
PsRel,
PsNeg,
PsNot,
PsCall
PsCall,
PsEq,
PsGe,
PsLe,
PsLt,
PsGt,
PsNe,
)
from ..ast.util import AstEqWrapper
......@@ -43,8 +49,6 @@ class ECContext:
self._ctx = ctx
self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict()
self._typifier = Typifier(ctx)
from ..emission import CAstPrinter
self._printer = CAstPrinter(0)
......@@ -72,7 +76,7 @@ class ECContext:
return f"__c_{code}"
def extract_expression(self, expr: PsExpression) -> PsSymbolExpr:
expr, dtype = self._typifier.typify_expression(expr)
dtype = expr.get_dtype()
expr_wrapped = AstEqWrapper(expr)
if expr_wrapped not in self._extracted_constants:
......@@ -105,6 +109,7 @@ class EliminateConstants:
self, ctx: KernelCreationContext, extract_constant_exprs: bool = False
):
self._ctx = ctx
self._typifer = Typifier(ctx)
self._fold_integers = True
self._fold_relations = True
......@@ -160,7 +165,7 @@ class EliminateConstants:
expr.children = [r[0] for r in subtree_results]
subtree_constness = [r[1] for r in subtree_results]
# Eliminate idempotence and dominance
# Eliminate idempotence, dominance, and trivial relations
match expr:
# Additive idempotence: Addition and subtraction of zero
case PsAdd(PsConstantExpr(c), other_op) if c.value == 0:
......@@ -220,6 +225,19 @@ class EliminateConstants:
case PsOr(other_op, PsConstantExpr(c)) if c.value:
return PsConstantExpr(c), True
# Trivial comparisons
case (
PsEq(op1, op2) | PsGe(op1, op2) | PsLe(op1, op2)
) if op1.structurally_equal(op2):
true = self._typifer(PsConstantExpr(PsConstant(True, PsBoolType())))
return true, True
case (
PsNe(op1, op2) | PsGt(op1, op2) | PsLt(op1, op2)
) if op1.structurally_equal(op2):
false = self._typifer(PsConstantExpr(PsConstant(False, PsBoolType())))
return false, True
# end match: no idempotence or dominance encountered
# Detect constant expressions
......@@ -239,6 +257,8 @@ class EliminateConstants:
or (self._fold_relations and is_rel)
)
folded: PsConstant | None
match expr:
case PsNeg(operand) | PsNot(operand):
if isinstance(operand, PsConstantExpr):
......@@ -252,7 +272,9 @@ class EliminateConstants:
return expr, True
case PsBinOp(op1, op2):
if isinstance(op1, PsConstantExpr) and isinstance(op2, PsConstantExpr):
if isinstance(op1, PsConstantExpr) and isinstance(
op2, PsConstantExpr
):
v1 = op1.constant.value
v2 = op2.constant.value
......@@ -277,7 +299,7 @@ class EliminateConstants:
return PsConstantExpr(folded), True
return expr, True
case PsCall(PsMathFunction(), _):
# TODO: Some math functions (min/max) might be safely folded
return expr, True
......
......@@ -120,3 +120,11 @@ def test_relations_folding():
expr = typify(PsEq(- i1, - i3))
result = elim(expr)
assert result.structurally_equal(false)
expr = typify(PsEq(x + y, f1 * (x + y)))
result = elim(expr)
assert result.structurally_equal(true)
expr = typify(PsGt(x + y, f1 * (x + y)))
result = elim(expr)
assert result.structurally_equal(false)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment