From a6ec04ed47307ff94ad88858da0ff679b917940e Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 4 Apr 2024 16:11:24 +0200
Subject: [PATCH] Add folding of booleans and relations to EliminateConstans.
 Fix folding of unary operations.

---
 src/pystencils/backend/ast/expressions.py     |   6 +-
 .../transformations/eliminate_constants.py    | 133 +++++++++++++-----
 .../test_constant_elimination.py              |  77 ++++++++--
 3 files changed, 164 insertions(+), 52 deletions(-)

diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index aacddfc8c..cafab6701 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -149,7 +149,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
         return self._constant == other._constant
 
     def __repr__(self) -> str:
-        return f"Constant({repr(self._constant)})"
+        return f"PsConstantExpr({repr(self._constant)})"
 
 
 class PsSubscript(PsLvalue, PsExpression):
@@ -431,6 +431,10 @@ class PsUnOp(PsExpression):
     @property
     def python_operator(self) -> None | Callable[[Any], Any]:
         return None
+    
+    def __repr__(self) -> str:
+        opname = self.__class__.__name__
+        return f"{opname}({repr(self._operand)})"
 
 
 class PsNeg(PsUnOp, PsNumericOpTrait):
diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index 22ad740fa..d48c70a07 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -14,12 +14,25 @@ from ..ast.expressions import (
     PsSub,
     PsMul,
     PsDiv,
+    PsAnd,
+    PsOr,
+    PsRel,
+    PsNeg,
+    PsNot,
+    PsCall
 )
 from ..ast.util import AstEqWrapper
 
 from ..constants import PsConstant
 from ..symbols import PsSymbol
-from ...types import PsIntegerType, PsIeeeFloatType, PsTypeError
+from ..functions import PsMathFunction
+from ...types import (
+    PsIntegerType,
+    PsIeeeFloatType,
+    PsNumericType,
+    PsBoolType,
+    PsTypeError,
+)
 
 
 __all__ = ["EliminateConstants"]
@@ -94,6 +107,7 @@ class EliminateConstants:
         self._ctx = ctx
 
         self._fold_integers = True
+        self._fold_relations = True
         self._fold_floats = False
         self._extract_constant_exprs = extract_constant_exprs
 
@@ -180,52 +194,93 @@ class EliminateConstants:
             case PsMul(other_op, PsConstantExpr(c)) if c.value == 0:
                 return PsConstantExpr(c), True
 
+            #   Logical idempotence
+            case PsAnd(PsConstantExpr(c), other_op) if c.value:
+                return other_op, all(subtree_constness)
+
+            case PsAnd(other_op, PsConstantExpr(c)) if c.value:
+                return other_op, all(subtree_constness)
+
+            case PsOr(PsConstantExpr(c), other_op) if not c.value:
+                return other_op, all(subtree_constness)
+
+            case PsOr(other_op, PsConstantExpr(c)) if not c.value:
+                return other_op, all(subtree_constness)
+
+            #   Logical dominance
+            case PsAnd(PsConstantExpr(c), other_op) if not c.value:
+                return PsConstantExpr(c), True
+
+            case PsAnd(other_op, PsConstantExpr(c)) if not c.value:
+                return PsConstantExpr(c), True
+
+            case PsOr(PsConstantExpr(c), other_op) if c.value:
+                return PsConstantExpr(c), True
+
+            case PsOr(other_op, PsConstantExpr(c)) if c.value:
+                return PsConstantExpr(c), True
+
         # end match: no idempotence or dominance encountered
 
         #   Detect constant expressions
         if all(subtree_constness):
-            #   Fold binary expressions where possible
-            if isinstance(expr, PsBinOp):
-                op1_transformed = expr.operand1
-                op2_transformed = expr.operand2
-
-                if isinstance(op1_transformed, PsConstantExpr) and isinstance(
-                    op2_transformed, PsConstantExpr
-                ):
-                    v1 = op1_transformed.constant.value
-                    v2 = op2_transformed.constant.value
-
-                    # assume they are of equal type
-                    dtype = op1_transformed.constant.dtype
-
-                    is_int = isinstance(dtype, PsIntegerType)
-                    is_float = isinstance(dtype, PsIeeeFloatType)
-
-                    if (self._fold_integers and is_int) or (
-                        self._fold_floats and is_float
-                    ):
+            dtype = expr.get_dtype()
+            assert isinstance(dtype, PsNumericType)
+
+            is_int = isinstance(dtype, PsIntegerType)
+            is_float = isinstance(dtype, PsIeeeFloatType)
+            is_bool = isinstance(dtype, PsBoolType)
+            is_rel = isinstance(expr, PsRel)
+
+            do_fold = (
+                is_bool
+                or (self._fold_integers and is_int)
+                or (self._fold_floats and is_float)
+                or (self._fold_relations and is_rel)
+            )
+
+            match expr:
+                case PsNeg(operand) | PsNot(operand):
+                    if isinstance(operand, PsConstantExpr):
+                        val = operand.constant.value
                         py_operator = expr.python_operator
 
-                        folded = None
-                        if py_operator is not None:
-                            folded = PsConstant(
-                                py_operator(v1, v2),
-                                dtype,
-                            )
-                        elif isinstance(expr, PsDiv):
-                            if isinstance(dtype, PsIntegerType):
-                                pass
-                                #   TODO: C integer division!
-                                # folded = PsConstant(v1 // v2, dtype)
-                            elif isinstance(dtype, PsIeeeFloatType):
-                                folded = PsConstant(v1 / v2, dtype)
-
-                        if folded is not None:
+                        if do_fold and py_operator is not None:
+                            folded = PsConstant(py_operator(val), dtype)
                             return PsConstantExpr(folded), True
 
-                expr.operand1 = op1_transformed
-                expr.operand2 = op2_transformed
-                return expr, True
+                    return expr, True
+
+                case PsBinOp(op1, op2):
+                    if isinstance(op1, PsConstantExpr) and isinstance(op2, PsConstantExpr):
+                        v1 = op1.constant.value
+                        v2 = op2.constant.value
+
+                        if do_fold:
+                            py_operator = expr.python_operator
+
+                            folded = None
+                            if py_operator is not None:
+                                folded = PsConstant(
+                                    py_operator(v1, v2),
+                                    dtype,
+                                )
+                            elif isinstance(expr, PsDiv):
+                                if is_int:
+                                    pass
+                                    #   TODO: C integer division!
+                                    # folded = PsConstant(v1 // v2, dtype)
+                                elif isinstance(dtype, PsIeeeFloatType):
+                                    folded = PsConstant(v1 / v2, dtype)
+
+                            if folded is not None:
+                                return PsConstantExpr(folded), True
+
+                    return expr, True
+                
+                case PsCall(PsMathFunction(), _):
+                    #   TODO: Some math functions (min/max) might be safely folded
+                    return expr, True
         # end if: this expression is not constant
 
         #   If required, extract constant subexpressions
diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py
index b2ac6fc5a..f8337d4e4 100644
--- a/tests/nbackend/transformations/test_constant_elimination.py
+++ b/tests/nbackend/transformations/test_constant_elimination.py
@@ -1,12 +1,22 @@
-from pystencils.backend.kernelcreation import KernelCreationContext
+from pystencils.backend.kernelcreation import KernelCreationContext, Typifier
 from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr
 from pystencils.backend.symbols import PsSymbol
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.transformations import EliminateConstants
 
-from pystencils.types.quick import Int, Fp
+from pystencils.backend.ast.expressions import (
+    PsAnd,
+    PsOr,
+    PsNot,
+    PsEq,
+    PsGt,
+)
 
-x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"]
+from pystencils.types.quick import Int, Fp, Bool
+
+x, y, z = [PsExpression.make(PsSymbol(name, Fp(32))) for name in "xyz"]
+p, q, r = [PsExpression.make(PsSymbol(name, Int(32))) for name in "xyz"]
+a, b, c = [PsExpression.make(PsSymbol(name, Bool())) for name in "abc"]
 
 f3p5 = PsExpression.make(PsConstant(3.5, Fp(32)))
 f42 = PsExpression.make(PsConstant(42, Fp(32)))
@@ -18,52 +28,95 @@ i0 = PsExpression.make(PsConstant(0, Int(32)))
 i1 = PsExpression.make(PsConstant(1, Int(32)))
 
 i3 = PsExpression.make(PsConstant(3, Int(32)))
+im3 = PsExpression.make(PsConstant(-3, Int(32)))
 i12 = PsExpression.make(PsConstant(12, Int(32)))
 
+true = PsExpression.make(PsConstant(True, Bool()))
+false = PsExpression.make(PsConstant(False, Bool()))
+
 
 def test_idempotence():
     ctx = KernelCreationContext()
+    typify = Typifier(ctx)
     elim = EliminateConstants(ctx)
 
-    expr = f42 * (f1 + f0) - f0
+    expr = typify(f42 * (f1 + f0) - f0)
     result = elim(expr)
     assert isinstance(result, PsConstantExpr) and result.structurally_equal(f42)
 
-    expr = (x + f0) * f3p5 + (f1 * y + f0) * f42
+    expr = typify((x + f0) * f3p5 + (f1 * y + f0) * f42)
     result = elim(expr)
     assert result.structurally_equal(x * f3p5 + y * f42)
 
-    expr = (f3p5 * f1) + (f42 * f1)
+    expr = typify((f3p5 * f1) + (f42 * f1))
     result = elim(expr)
     #   do not fold floats by default
     assert expr.structurally_equal(f3p5 + f42)
 
-    expr = f1 * x + f0 + (f0 + f0 + f1 + f0) * y
+    expr = typify(f1 * x + f0 + (f0 + f0 + f1 + f0) * y)
     result = elim(expr)
     assert result.structurally_equal(x + y)
 
 
 def test_int_folding():
     ctx = KernelCreationContext()
+    typify = Typifier(ctx)
     elim = EliminateConstants(ctx)
 
-    expr = (i1 * x + i1 * i3) + i1 * i12
+    expr = typify((i1 * p + i1 * -i3) + i1 * i12)
     result = elim(expr)
-    assert result.structurally_equal((x + i3) + i12)
+    assert result.structurally_equal((p + im3) + i12)
 
-    expr = (i1 + i1 + i1 + i0 + i0 + i1) * (i1 + i1 + i1)
+    expr = typify((i1 + i1 + i1 + i0 + i0 + i1) * (i1 + i1 + i1))
     result = elim(expr)
     assert result.structurally_equal(i12)
 
 
 def test_zero_dominance():
     ctx = KernelCreationContext()
+    typify = Typifier(ctx)
     elim = EliminateConstants(ctx)
 
-    expr = (f0 * x) + (y * f0) + f1
+    expr = typify((f0 * x) + (y * f0) + f1)
     result = elim(expr)
     assert result.structurally_equal(f1)
 
-    expr = (i3 + i12 * (x + y) + x / (i3 * y)) * i0
+    expr = typify((i3 + i12 * (p + q) + p / (i3 * q)) * i0)
     result = elim(expr)
     assert result.structurally_equal(i0)
+
+
+def test_boolean_folding():
+    ctx = KernelCreationContext()
+    typify = Typifier(ctx)
+    elim = EliminateConstants(ctx)
+
+    expr = typify(PsNot(PsAnd(false, PsOr(true, a))))
+    result = elim(expr)
+    assert result.structurally_equal(true)
+
+    expr = typify(PsOr(PsAnd(a, b), PsNot(false)))
+    result = elim(expr)
+    assert result.structurally_equal(true)
+
+    expr = typify(PsAnd(c, PsAnd(true, PsAnd(a, PsOr(false, b)))))
+    result = elim(expr)
+    assert result.structurally_equal(PsAnd(c, PsAnd(a, b)))
+
+
+def test_relations_folding():
+    ctx = KernelCreationContext()
+    typify = Typifier(ctx)
+    elim = EliminateConstants(ctx)
+
+    expr = typify(PsGt(p * i0, - i1))
+    result = elim(expr)
+    assert result.structurally_equal(true)
+
+    expr = typify(PsEq(i1 + i1 + i1, i3))
+    result = elim(expr)
+    assert result.structurally_equal(true)
+
+    expr = typify(PsEq(- i1, - i3))
+    result = elim(expr)
+    assert result.structurally_equal(false)
-- 
GitLab