diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index aec23809b5d0657ccd303ccc9d4dfcc1a6e152a2..0666d96873d4bdd3d722a7912b6e704b4aee1cf8 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -590,7 +590,7 @@ class PsAnd(PsBinOp, PsBoolOpTrait): class PsOr(PsBinOp, PsBoolOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.and_ + return operator.or_ class PsNot(PsUnOp, PsBoolOpTrait): diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 135fb3d38cebfb6e9ff7e3861710528624a752c2..c30d2fd6772bdc571bf0acad1d7f436f39dfae90 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -96,7 +96,7 @@ class Ops(Enum): LeftShift = (7, LR.Left) RightShift = (7, LR.Left) - RelOp = (9, LR.Left) + IneqOp = (9, LR.Left) EqOp = (10, LR.Left) @@ -297,7 +297,7 @@ class CAstPrinter: operand_code = self.visit(operand, pc) pc.pop_op() - return pc.parenthesize(f"!{operand_code}", Ops.Neg) + return pc.parenthesize(f"!{operand_code}", Ops.Not) case PsDeref(operand): pc.push_op(Ops.Deref, LR.Right) @@ -373,12 +373,12 @@ class CAstPrinter: case PsNe(): return ("!=", Ops.EqOp) case PsGt(): - return (">", Ops.RelOp) + return (">", Ops.IneqOp) case PsGe(): - return (">=", Ops.RelOp) + return (">=", Ops.IneqOp) case PsLt(): - return ("<", Ops.RelOp) + return ("<", Ops.IneqOp) case PsLe(): - return ("<=", Ops.RelOp) + return ("<=", Ops.IneqOp) case _: assert False diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index d79f346b25c87be539063013005693128716f868..7678dbd8c6ce783585fb7095b201e9f92e65e485 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -109,7 +109,7 @@ class EliminateConstants: self, ctx: KernelCreationContext, extract_constant_exprs: bool = False ): self._ctx = ctx - self._typifer = Typifier(ctx) + self._typify = Typifier(ctx) self._fold_integers = True self._fold_relations = True @@ -229,13 +229,13 @@ class EliminateConstants: case ( PsEq(op1, op2) | PsGe(op1, op2) | PsLe(op1, op2) ) if op1.structurally_equal(op2): - true = self._typifer(PsConstantExpr(PsConstant(True, PsBoolType()))) + true = self._typify(PsConstantExpr(PsConstant(True, PsBoolType()))) return true, True case ( PsNe(op1, op2) | PsGt(op1, op2) | PsLt(op1, op2) ) if op1.structurally_equal(op2): - false = self._typifer(PsConstantExpr(PsConstant(False, PsBoolType()))) + false = self._typify(PsConstantExpr(PsConstant(False, PsBoolType()))) return false, True # end match: no idempotence or dominance encountered @@ -267,7 +267,7 @@ class EliminateConstants: if do_fold and py_operator is not None: folded = PsConstant(py_operator(val), dtype) - return PsConstantExpr(folded), True + return self._typify(PsConstantExpr(folded)), True return expr, True @@ -296,7 +296,7 @@ class EliminateConstants: folded = PsConstant(v1 / v2, dtype) if folded is not None: - return PsConstantExpr(folded), True + return self._typify(PsConstantExpr(folded)), True return expr, True diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index 1bea3f5164b5f10d6a1940b197994a95490803ba..48df23ee193e24ff6736d74c22317f04dddc056c 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -15,7 +15,7 @@ from pystencils.backend.ast.expressions import ( from pystencils.types.quick import Int, Fp, Bool x, y, z = [PsExpression.make(PsSymbol(name, Fp(32))) for name in "xyz"] -p, q, r = [PsExpression.make(PsSymbol(name, Int(32))) for name in "xyz"] +p, q, r = [PsExpression.make(PsSymbol(name, Int(32))) for name in "pqr"] a, b, c = [PsExpression.make(PsSymbol(name, Bool())) for name in "abc"] f3p5 = PsExpression.make(PsConstant(3.5, Fp(32)))