diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index cafab67011afa8b2049b6c3f7301dbad7624af24..01df0ccb6d1a6f530b8e47d872a78ad29852f6c1 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -431,7 +431,7 @@ class PsUnOp(PsExpression): @property def python_operator(self) -> None | Callable[[Any], Any]: return None - + def __repr__(self) -> str: opname = self.__class__.__name__ return f"{opname}({repr(self._operand)})" diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index d48c70a0715b2374b215ce82fb4084f97a65bdd9..d79f346b25c87be539063013005693128716f868 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -19,7 +19,13 @@ from ..ast.expressions import ( PsRel, PsNeg, PsNot, - PsCall + PsCall, + PsEq, + PsGe, + PsLe, + PsLt, + PsGt, + PsNe, ) from ..ast.util import AstEqWrapper @@ -43,8 +49,6 @@ class ECContext: self._ctx = ctx self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict() - self._typifier = Typifier(ctx) - from ..emission import CAstPrinter self._printer = CAstPrinter(0) @@ -72,7 +76,7 @@ class ECContext: return f"__c_{code}" def extract_expression(self, expr: PsExpression) -> PsSymbolExpr: - expr, dtype = self._typifier.typify_expression(expr) + dtype = expr.get_dtype() expr_wrapped = AstEqWrapper(expr) if expr_wrapped not in self._extracted_constants: @@ -105,6 +109,7 @@ class EliminateConstants: self, ctx: KernelCreationContext, extract_constant_exprs: bool = False ): self._ctx = ctx + self._typifer = Typifier(ctx) self._fold_integers = True self._fold_relations = True @@ -160,7 +165,7 @@ class EliminateConstants: expr.children = [r[0] for r in subtree_results] subtree_constness = [r[1] for r in subtree_results] - # Eliminate idempotence and dominance + # Eliminate idempotence, dominance, and trivial relations match expr: # Additive idempotence: Addition and subtraction of zero case PsAdd(PsConstantExpr(c), other_op) if c.value == 0: @@ -220,6 +225,19 @@ class EliminateConstants: case PsOr(other_op, PsConstantExpr(c)) if c.value: return PsConstantExpr(c), True + # Trivial comparisons + case ( + PsEq(op1, op2) | PsGe(op1, op2) | PsLe(op1, op2) + ) if op1.structurally_equal(op2): + true = self._typifer(PsConstantExpr(PsConstant(True, PsBoolType()))) + return true, True + + case ( + PsNe(op1, op2) | PsGt(op1, op2) | PsLt(op1, op2) + ) if op1.structurally_equal(op2): + false = self._typifer(PsConstantExpr(PsConstant(False, PsBoolType()))) + return false, True + # end match: no idempotence or dominance encountered # Detect constant expressions @@ -239,6 +257,8 @@ class EliminateConstants: or (self._fold_relations and is_rel) ) + folded: PsConstant | None + match expr: case PsNeg(operand) | PsNot(operand): if isinstance(operand, PsConstantExpr): @@ -252,7 +272,9 @@ class EliminateConstants: return expr, True case PsBinOp(op1, op2): - if isinstance(op1, PsConstantExpr) and isinstance(op2, PsConstantExpr): + if isinstance(op1, PsConstantExpr) and isinstance( + op2, PsConstantExpr + ): v1 = op1.constant.value v2 = op2.constant.value @@ -277,7 +299,7 @@ class EliminateConstants: return PsConstantExpr(folded), True return expr, True - + case PsCall(PsMathFunction(), _): # TODO: Some math functions (min/max) might be safely folded return expr, True diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index f8337d4e4f359487df83b4c44d211507f5f3b392..1bea3f5164b5f10d6a1940b197994a95490803ba 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -120,3 +120,11 @@ def test_relations_folding(): expr = typify(PsEq(- i1, - i3)) result = elim(expr) assert result.structurally_equal(false) + + expr = typify(PsEq(x + y, f1 * (x + y))) + result = elim(expr) + assert result.structurally_equal(true) + + expr = typify(PsGt(x + y, f1 * (x + y))) + result = elim(expr) + assert result.structurally_equal(false)