Skip to content
Snippets Groups Projects
Commit 9dbacce2 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Add elimination of trivial comparisons

parent a6ec04ed
No related branches found
No related tags found
1 merge request!375Support for Boolean Operations and Relations
Pipeline #64918 passed
...@@ -431,7 +431,7 @@ class PsUnOp(PsExpression): ...@@ -431,7 +431,7 @@ class PsUnOp(PsExpression):
@property @property
def python_operator(self) -> None | Callable[[Any], Any]: def python_operator(self) -> None | Callable[[Any], Any]:
return None return None
def __repr__(self) -> str: def __repr__(self) -> str:
opname = self.__class__.__name__ opname = self.__class__.__name__
return f"{opname}({repr(self._operand)})" return f"{opname}({repr(self._operand)})"
......
...@@ -19,7 +19,13 @@ from ..ast.expressions import ( ...@@ -19,7 +19,13 @@ from ..ast.expressions import (
PsRel, PsRel,
PsNeg, PsNeg,
PsNot, PsNot,
PsCall PsCall,
PsEq,
PsGe,
PsLe,
PsLt,
PsGt,
PsNe,
) )
from ..ast.util import AstEqWrapper from ..ast.util import AstEqWrapper
...@@ -43,8 +49,6 @@ class ECContext: ...@@ -43,8 +49,6 @@ class ECContext:
self._ctx = ctx self._ctx = ctx
self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict() self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict()
self._typifier = Typifier(ctx)
from ..emission import CAstPrinter from ..emission import CAstPrinter
self._printer = CAstPrinter(0) self._printer = CAstPrinter(0)
...@@ -72,7 +76,7 @@ class ECContext: ...@@ -72,7 +76,7 @@ class ECContext:
return f"__c_{code}" return f"__c_{code}"
def extract_expression(self, expr: PsExpression) -> PsSymbolExpr: def extract_expression(self, expr: PsExpression) -> PsSymbolExpr:
expr, dtype = self._typifier.typify_expression(expr) dtype = expr.get_dtype()
expr_wrapped = AstEqWrapper(expr) expr_wrapped = AstEqWrapper(expr)
if expr_wrapped not in self._extracted_constants: if expr_wrapped not in self._extracted_constants:
...@@ -105,6 +109,7 @@ class EliminateConstants: ...@@ -105,6 +109,7 @@ class EliminateConstants:
self, ctx: KernelCreationContext, extract_constant_exprs: bool = False self, ctx: KernelCreationContext, extract_constant_exprs: bool = False
): ):
self._ctx = ctx self._ctx = ctx
self._typifer = Typifier(ctx)
self._fold_integers = True self._fold_integers = True
self._fold_relations = True self._fold_relations = True
...@@ -160,7 +165,7 @@ class EliminateConstants: ...@@ -160,7 +165,7 @@ class EliminateConstants:
expr.children = [r[0] for r in subtree_results] expr.children = [r[0] for r in subtree_results]
subtree_constness = [r[1] for r in subtree_results] subtree_constness = [r[1] for r in subtree_results]
# Eliminate idempotence and dominance # Eliminate idempotence, dominance, and trivial relations
match expr: match expr:
# Additive idempotence: Addition and subtraction of zero # Additive idempotence: Addition and subtraction of zero
case PsAdd(PsConstantExpr(c), other_op) if c.value == 0: case PsAdd(PsConstantExpr(c), other_op) if c.value == 0:
...@@ -220,6 +225,19 @@ class EliminateConstants: ...@@ -220,6 +225,19 @@ class EliminateConstants:
case PsOr(other_op, PsConstantExpr(c)) if c.value: case PsOr(other_op, PsConstantExpr(c)) if c.value:
return PsConstantExpr(c), True return PsConstantExpr(c), True
# Trivial comparisons
case (
PsEq(op1, op2) | PsGe(op1, op2) | PsLe(op1, op2)
) if op1.structurally_equal(op2):
true = self._typifer(PsConstantExpr(PsConstant(True, PsBoolType())))
return true, True
case (
PsNe(op1, op2) | PsGt(op1, op2) | PsLt(op1, op2)
) if op1.structurally_equal(op2):
false = self._typifer(PsConstantExpr(PsConstant(False, PsBoolType())))
return false, True
# end match: no idempotence or dominance encountered # end match: no idempotence or dominance encountered
# Detect constant expressions # Detect constant expressions
...@@ -239,6 +257,8 @@ class EliminateConstants: ...@@ -239,6 +257,8 @@ class EliminateConstants:
or (self._fold_relations and is_rel) or (self._fold_relations and is_rel)
) )
folded: PsConstant | None
match expr: match expr:
case PsNeg(operand) | PsNot(operand): case PsNeg(operand) | PsNot(operand):
if isinstance(operand, PsConstantExpr): if isinstance(operand, PsConstantExpr):
...@@ -252,7 +272,9 @@ class EliminateConstants: ...@@ -252,7 +272,9 @@ class EliminateConstants:
return expr, True return expr, True
case PsBinOp(op1, op2): case PsBinOp(op1, op2):
if isinstance(op1, PsConstantExpr) and isinstance(op2, PsConstantExpr): if isinstance(op1, PsConstantExpr) and isinstance(
op2, PsConstantExpr
):
v1 = op1.constant.value v1 = op1.constant.value
v2 = op2.constant.value v2 = op2.constant.value
...@@ -277,7 +299,7 @@ class EliminateConstants: ...@@ -277,7 +299,7 @@ class EliminateConstants:
return PsConstantExpr(folded), True return PsConstantExpr(folded), True
return expr, True return expr, True
case PsCall(PsMathFunction(), _): case PsCall(PsMathFunction(), _):
# TODO: Some math functions (min/max) might be safely folded # TODO: Some math functions (min/max) might be safely folded
return expr, True return expr, True
......
...@@ -120,3 +120,11 @@ def test_relations_folding(): ...@@ -120,3 +120,11 @@ def test_relations_folding():
expr = typify(PsEq(- i1, - i3)) expr = typify(PsEq(- i1, - i3))
result = elim(expr) result = elim(expr)
assert result.structurally_equal(false) assert result.structurally_equal(false)
expr = typify(PsEq(x + y, f1 * (x + y)))
result = elim(expr)
assert result.structurally_equal(true)
expr = typify(PsGt(x + y, f1 * (x + y)))
result = elim(expr)
assert result.structurally_equal(false)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment