From a6ec04ed47307ff94ad88858da0ff679b917940e Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 4 Apr 2024 16:11:24 +0200 Subject: [PATCH] Add folding of booleans and relations to EliminateConstans. Fix folding of unary operations. --- src/pystencils/backend/ast/expressions.py | 6 +- .../transformations/eliminate_constants.py | 133 +++++++++++++----- .../test_constant_elimination.py | 77 ++++++++-- 3 files changed, 164 insertions(+), 52 deletions(-) diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index aacddfc8c..cafab6701 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -149,7 +149,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): return self._constant == other._constant def __repr__(self) -> str: - return f"Constant({repr(self._constant)})" + return f"PsConstantExpr({repr(self._constant)})" class PsSubscript(PsLvalue, PsExpression): @@ -431,6 +431,10 @@ class PsUnOp(PsExpression): @property def python_operator(self) -> None | Callable[[Any], Any]: return None + + def __repr__(self) -> str: + opname = self.__class__.__name__ + return f"{opname}({repr(self._operand)})" class PsNeg(PsUnOp, PsNumericOpTrait): diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 22ad740fa..d48c70a07 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -14,12 +14,25 @@ from ..ast.expressions import ( PsSub, PsMul, PsDiv, + PsAnd, + PsOr, + PsRel, + PsNeg, + PsNot, + PsCall ) from ..ast.util import AstEqWrapper from ..constants import PsConstant from ..symbols import PsSymbol -from ...types import PsIntegerType, PsIeeeFloatType, PsTypeError +from ..functions import PsMathFunction +from ...types import ( + PsIntegerType, + PsIeeeFloatType, + PsNumericType, + PsBoolType, + PsTypeError, +) __all__ = ["EliminateConstants"] @@ -94,6 +107,7 @@ class EliminateConstants: self._ctx = ctx self._fold_integers = True + self._fold_relations = True self._fold_floats = False self._extract_constant_exprs = extract_constant_exprs @@ -180,52 +194,93 @@ class EliminateConstants: case PsMul(other_op, PsConstantExpr(c)) if c.value == 0: return PsConstantExpr(c), True + # Logical idempotence + case PsAnd(PsConstantExpr(c), other_op) if c.value: + return other_op, all(subtree_constness) + + case PsAnd(other_op, PsConstantExpr(c)) if c.value: + return other_op, all(subtree_constness) + + case PsOr(PsConstantExpr(c), other_op) if not c.value: + return other_op, all(subtree_constness) + + case PsOr(other_op, PsConstantExpr(c)) if not c.value: + return other_op, all(subtree_constness) + + # Logical dominance + case PsAnd(PsConstantExpr(c), other_op) if not c.value: + return PsConstantExpr(c), True + + case PsAnd(other_op, PsConstantExpr(c)) if not c.value: + return PsConstantExpr(c), True + + case PsOr(PsConstantExpr(c), other_op) if c.value: + return PsConstantExpr(c), True + + case PsOr(other_op, PsConstantExpr(c)) if c.value: + return PsConstantExpr(c), True + # end match: no idempotence or dominance encountered # Detect constant expressions if all(subtree_constness): - # Fold binary expressions where possible - if isinstance(expr, PsBinOp): - op1_transformed = expr.operand1 - op2_transformed = expr.operand2 - - if isinstance(op1_transformed, PsConstantExpr) and isinstance( - op2_transformed, PsConstantExpr - ): - v1 = op1_transformed.constant.value - v2 = op2_transformed.constant.value - - # assume they are of equal type - dtype = op1_transformed.constant.dtype - - is_int = isinstance(dtype, PsIntegerType) - is_float = isinstance(dtype, PsIeeeFloatType) - - if (self._fold_integers and is_int) or ( - self._fold_floats and is_float - ): + dtype = expr.get_dtype() + assert isinstance(dtype, PsNumericType) + + is_int = isinstance(dtype, PsIntegerType) + is_float = isinstance(dtype, PsIeeeFloatType) + is_bool = isinstance(dtype, PsBoolType) + is_rel = isinstance(expr, PsRel) + + do_fold = ( + is_bool + or (self._fold_integers and is_int) + or (self._fold_floats and is_float) + or (self._fold_relations and is_rel) + ) + + match expr: + case PsNeg(operand) | PsNot(operand): + if isinstance(operand, PsConstantExpr): + val = operand.constant.value py_operator = expr.python_operator - folded = None - if py_operator is not None: - folded = PsConstant( - py_operator(v1, v2), - dtype, - ) - elif isinstance(expr, PsDiv): - if isinstance(dtype, PsIntegerType): - pass - # TODO: C integer division! - # folded = PsConstant(v1 // v2, dtype) - elif isinstance(dtype, PsIeeeFloatType): - folded = PsConstant(v1 / v2, dtype) - - if folded is not None: + if do_fold and py_operator is not None: + folded = PsConstant(py_operator(val), dtype) return PsConstantExpr(folded), True - expr.operand1 = op1_transformed - expr.operand2 = op2_transformed - return expr, True + return expr, True + + case PsBinOp(op1, op2): + if isinstance(op1, PsConstantExpr) and isinstance(op2, PsConstantExpr): + v1 = op1.constant.value + v2 = op2.constant.value + + if do_fold: + py_operator = expr.python_operator + + folded = None + if py_operator is not None: + folded = PsConstant( + py_operator(v1, v2), + dtype, + ) + elif isinstance(expr, PsDiv): + if is_int: + pass + # TODO: C integer division! + # folded = PsConstant(v1 // v2, dtype) + elif isinstance(dtype, PsIeeeFloatType): + folded = PsConstant(v1 / v2, dtype) + + if folded is not None: + return PsConstantExpr(folded), True + + return expr, True + + case PsCall(PsMathFunction(), _): + # TODO: Some math functions (min/max) might be safely folded + return expr, True # end if: this expression is not constant # If required, extract constant subexpressions diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index b2ac6fc5a..f8337d4e4 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -1,12 +1,22 @@ -from pystencils.backend.kernelcreation import KernelCreationContext +from pystencils.backend.kernelcreation import KernelCreationContext, Typifier from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr from pystencils.backend.symbols import PsSymbol from pystencils.backend.constants import PsConstant from pystencils.backend.transformations import EliminateConstants -from pystencils.types.quick import Int, Fp +from pystencils.backend.ast.expressions import ( + PsAnd, + PsOr, + PsNot, + PsEq, + PsGt, +) -x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"] +from pystencils.types.quick import Int, Fp, Bool + +x, y, z = [PsExpression.make(PsSymbol(name, Fp(32))) for name in "xyz"] +p, q, r = [PsExpression.make(PsSymbol(name, Int(32))) for name in "xyz"] +a, b, c = [PsExpression.make(PsSymbol(name, Bool())) for name in "abc"] f3p5 = PsExpression.make(PsConstant(3.5, Fp(32))) f42 = PsExpression.make(PsConstant(42, Fp(32))) @@ -18,52 +28,95 @@ i0 = PsExpression.make(PsConstant(0, Int(32))) i1 = PsExpression.make(PsConstant(1, Int(32))) i3 = PsExpression.make(PsConstant(3, Int(32))) +im3 = PsExpression.make(PsConstant(-3, Int(32))) i12 = PsExpression.make(PsConstant(12, Int(32))) +true = PsExpression.make(PsConstant(True, Bool())) +false = PsExpression.make(PsConstant(False, Bool())) + def test_idempotence(): ctx = KernelCreationContext() + typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = f42 * (f1 + f0) - f0 + expr = typify(f42 * (f1 + f0) - f0) result = elim(expr) assert isinstance(result, PsConstantExpr) and result.structurally_equal(f42) - expr = (x + f0) * f3p5 + (f1 * y + f0) * f42 + expr = typify((x + f0) * f3p5 + (f1 * y + f0) * f42) result = elim(expr) assert result.structurally_equal(x * f3p5 + y * f42) - expr = (f3p5 * f1) + (f42 * f1) + expr = typify((f3p5 * f1) + (f42 * f1)) result = elim(expr) # do not fold floats by default assert expr.structurally_equal(f3p5 + f42) - expr = f1 * x + f0 + (f0 + f0 + f1 + f0) * y + expr = typify(f1 * x + f0 + (f0 + f0 + f1 + f0) * y) result = elim(expr) assert result.structurally_equal(x + y) def test_int_folding(): ctx = KernelCreationContext() + typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = (i1 * x + i1 * i3) + i1 * i12 + expr = typify((i1 * p + i1 * -i3) + i1 * i12) result = elim(expr) - assert result.structurally_equal((x + i3) + i12) + assert result.structurally_equal((p + im3) + i12) - expr = (i1 + i1 + i1 + i0 + i0 + i1) * (i1 + i1 + i1) + expr = typify((i1 + i1 + i1 + i0 + i0 + i1) * (i1 + i1 + i1)) result = elim(expr) assert result.structurally_equal(i12) def test_zero_dominance(): ctx = KernelCreationContext() + typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = (f0 * x) + (y * f0) + f1 + expr = typify((f0 * x) + (y * f0) + f1) result = elim(expr) assert result.structurally_equal(f1) - expr = (i3 + i12 * (x + y) + x / (i3 * y)) * i0 + expr = typify((i3 + i12 * (p + q) + p / (i3 * q)) * i0) result = elim(expr) assert result.structurally_equal(i0) + + +def test_boolean_folding(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateConstants(ctx) + + expr = typify(PsNot(PsAnd(false, PsOr(true, a)))) + result = elim(expr) + assert result.structurally_equal(true) + + expr = typify(PsOr(PsAnd(a, b), PsNot(false))) + result = elim(expr) + assert result.structurally_equal(true) + + expr = typify(PsAnd(c, PsAnd(true, PsAnd(a, PsOr(false, b))))) + result = elim(expr) + assert result.structurally_equal(PsAnd(c, PsAnd(a, b))) + + +def test_relations_folding(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateConstants(ctx) + + expr = typify(PsGt(p * i0, - i1)) + result = elim(expr) + assert result.structurally_equal(true) + + expr = typify(PsEq(i1 + i1 + i1, i3)) + result = elim(expr) + assert result.structurally_equal(true) + + expr = typify(PsEq(- i1, - i3)) + result = elim(expr) + assert result.structurally_equal(false) -- GitLab