From edf38f261857da0e4a596fa4998ee28dba3a651f Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 5 Apr 2024 11:03:32 +0200 Subject: [PATCH] incorporate review --- src/pystencils/backend/ast/expressions.py | 2 +- src/pystencils/backend/emission.py | 12 ++++++------ .../backend/transformations/eliminate_constants.py | 10 +++++----- .../transformations/test_constant_elimination.py | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index aec23809b..0666d9687 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -590,7 +590,7 @@ class PsAnd(PsBinOp, PsBoolOpTrait): class PsOr(PsBinOp, PsBoolOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.and_ + return operator.or_ class PsNot(PsUnOp, PsBoolOpTrait): diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 135fb3d38..c30d2fd67 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -96,7 +96,7 @@ class Ops(Enum): LeftShift = (7, LR.Left) RightShift = (7, LR.Left) - RelOp = (9, LR.Left) + IneqOp = (9, LR.Left) EqOp = (10, LR.Left) @@ -297,7 +297,7 @@ class CAstPrinter: operand_code = self.visit(operand, pc) pc.pop_op() - return pc.parenthesize(f"!{operand_code}", Ops.Neg) + return pc.parenthesize(f"!{operand_code}", Ops.Not) case PsDeref(operand): pc.push_op(Ops.Deref, LR.Right) @@ -373,12 +373,12 @@ class CAstPrinter: case PsNe(): return ("!=", Ops.EqOp) case PsGt(): - return (">", Ops.RelOp) + return (">", Ops.IneqOp) case PsGe(): - return (">=", Ops.RelOp) + return (">=", Ops.IneqOp) case PsLt(): - return ("<", Ops.RelOp) + return ("<", Ops.IneqOp) case PsLe(): - return ("<=", Ops.RelOp) + return ("<=", Ops.IneqOp) case _: assert False diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index d79f346b2..7678dbd8c 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -109,7 +109,7 @@ class EliminateConstants: self, ctx: KernelCreationContext, extract_constant_exprs: bool = False ): self._ctx = ctx - self._typifer = Typifier(ctx) + self._typify = Typifier(ctx) self._fold_integers = True self._fold_relations = True @@ -229,13 +229,13 @@ class EliminateConstants: case ( PsEq(op1, op2) | PsGe(op1, op2) | PsLe(op1, op2) ) if op1.structurally_equal(op2): - true = self._typifer(PsConstantExpr(PsConstant(True, PsBoolType()))) + true = self._typify(PsConstantExpr(PsConstant(True, PsBoolType()))) return true, True case ( PsNe(op1, op2) | PsGt(op1, op2) | PsLt(op1, op2) ) if op1.structurally_equal(op2): - false = self._typifer(PsConstantExpr(PsConstant(False, PsBoolType()))) + false = self._typify(PsConstantExpr(PsConstant(False, PsBoolType()))) return false, True # end match: no idempotence or dominance encountered @@ -267,7 +267,7 @@ class EliminateConstants: if do_fold and py_operator is not None: folded = PsConstant(py_operator(val), dtype) - return PsConstantExpr(folded), True + return self._typify(PsConstantExpr(folded)), True return expr, True @@ -296,7 +296,7 @@ class EliminateConstants: folded = PsConstant(v1 / v2, dtype) if folded is not None: - return PsConstantExpr(folded), True + return self._typify(PsConstantExpr(folded)), True return expr, True diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index 1bea3f516..48df23ee1 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -15,7 +15,7 @@ from pystencils.backend.ast.expressions import ( from pystencils.types.quick import Int, Fp, Bool x, y, z = [PsExpression.make(PsSymbol(name, Fp(32))) for name in "xyz"] -p, q, r = [PsExpression.make(PsSymbol(name, Int(32))) for name in "xyz"] +p, q, r = [PsExpression.make(PsSymbol(name, Int(32))) for name in "pqr"] a, b, c = [PsExpression.make(PsSymbol(name, Bool())) for name in "abc"] f3p5 = PsExpression.make(PsConstant(3.5, Fp(32))) -- GitLab