Skip to content
Snippets Groups Projects
Commit edf38f26 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

incorporate review

parent 93d31273
No related branches found
No related tags found
1 merge request!375Support for Boolean Operations and Relations
Pipeline #64935 passed
...@@ -590,7 +590,7 @@ class PsAnd(PsBinOp, PsBoolOpTrait): ...@@ -590,7 +590,7 @@ class PsAnd(PsBinOp, PsBoolOpTrait):
class PsOr(PsBinOp, PsBoolOpTrait): class PsOr(PsBinOp, PsBoolOpTrait):
@property @property
def python_operator(self) -> Callable[[Any, Any], Any] | None: def python_operator(self) -> Callable[[Any, Any], Any] | None:
return operator.and_ return operator.or_
class PsNot(PsUnOp, PsBoolOpTrait): class PsNot(PsUnOp, PsBoolOpTrait):
......
...@@ -96,7 +96,7 @@ class Ops(Enum): ...@@ -96,7 +96,7 @@ class Ops(Enum):
LeftShift = (7, LR.Left) LeftShift = (7, LR.Left)
RightShift = (7, LR.Left) RightShift = (7, LR.Left)
RelOp = (9, LR.Left) IneqOp = (9, LR.Left)
EqOp = (10, LR.Left) EqOp = (10, LR.Left)
...@@ -297,7 +297,7 @@ class CAstPrinter: ...@@ -297,7 +297,7 @@ class CAstPrinter:
operand_code = self.visit(operand, pc) operand_code = self.visit(operand, pc)
pc.pop_op() pc.pop_op()
return pc.parenthesize(f"!{operand_code}", Ops.Neg) return pc.parenthesize(f"!{operand_code}", Ops.Not)
case PsDeref(operand): case PsDeref(operand):
pc.push_op(Ops.Deref, LR.Right) pc.push_op(Ops.Deref, LR.Right)
...@@ -373,12 +373,12 @@ class CAstPrinter: ...@@ -373,12 +373,12 @@ class CAstPrinter:
case PsNe(): case PsNe():
return ("!=", Ops.EqOp) return ("!=", Ops.EqOp)
case PsGt(): case PsGt():
return (">", Ops.RelOp) return (">", Ops.IneqOp)
case PsGe(): case PsGe():
return (">=", Ops.RelOp) return (">=", Ops.IneqOp)
case PsLt(): case PsLt():
return ("<", Ops.RelOp) return ("<", Ops.IneqOp)
case PsLe(): case PsLe():
return ("<=", Ops.RelOp) return ("<=", Ops.IneqOp)
case _: case _:
assert False assert False
...@@ -109,7 +109,7 @@ class EliminateConstants: ...@@ -109,7 +109,7 @@ class EliminateConstants:
self, ctx: KernelCreationContext, extract_constant_exprs: bool = False self, ctx: KernelCreationContext, extract_constant_exprs: bool = False
): ):
self._ctx = ctx self._ctx = ctx
self._typifer = Typifier(ctx) self._typify = Typifier(ctx)
self._fold_integers = True self._fold_integers = True
self._fold_relations = True self._fold_relations = True
...@@ -229,13 +229,13 @@ class EliminateConstants: ...@@ -229,13 +229,13 @@ class EliminateConstants:
case ( case (
PsEq(op1, op2) | PsGe(op1, op2) | PsLe(op1, op2) PsEq(op1, op2) | PsGe(op1, op2) | PsLe(op1, op2)
) if op1.structurally_equal(op2): ) if op1.structurally_equal(op2):
true = self._typifer(PsConstantExpr(PsConstant(True, PsBoolType()))) true = self._typify(PsConstantExpr(PsConstant(True, PsBoolType())))
return true, True return true, True
case ( case (
PsNe(op1, op2) | PsGt(op1, op2) | PsLt(op1, op2) PsNe(op1, op2) | PsGt(op1, op2) | PsLt(op1, op2)
) if op1.structurally_equal(op2): ) if op1.structurally_equal(op2):
false = self._typifer(PsConstantExpr(PsConstant(False, PsBoolType()))) false = self._typify(PsConstantExpr(PsConstant(False, PsBoolType())))
return false, True return false, True
# end match: no idempotence or dominance encountered # end match: no idempotence or dominance encountered
...@@ -267,7 +267,7 @@ class EliminateConstants: ...@@ -267,7 +267,7 @@ class EliminateConstants:
if do_fold and py_operator is not None: if do_fold and py_operator is not None:
folded = PsConstant(py_operator(val), dtype) folded = PsConstant(py_operator(val), dtype)
return PsConstantExpr(folded), True return self._typify(PsConstantExpr(folded)), True
return expr, True return expr, True
...@@ -296,7 +296,7 @@ class EliminateConstants: ...@@ -296,7 +296,7 @@ class EliminateConstants:
folded = PsConstant(v1 / v2, dtype) folded = PsConstant(v1 / v2, dtype)
if folded is not None: if folded is not None:
return PsConstantExpr(folded), True return self._typify(PsConstantExpr(folded)), True
return expr, True return expr, True
......
...@@ -15,7 +15,7 @@ from pystencils.backend.ast.expressions import ( ...@@ -15,7 +15,7 @@ from pystencils.backend.ast.expressions import (
from pystencils.types.quick import Int, Fp, Bool from pystencils.types.quick import Int, Fp, Bool
x, y, z = [PsExpression.make(PsSymbol(name, Fp(32))) for name in "xyz"] x, y, z = [PsExpression.make(PsSymbol(name, Fp(32))) for name in "xyz"]
p, q, r = [PsExpression.make(PsSymbol(name, Int(32))) for name in "xyz"] p, q, r = [PsExpression.make(PsSymbol(name, Int(32))) for name in "pqr"]
a, b, c = [PsExpression.make(PsSymbol(name, Bool())) for name in "abc"] a, b, c = [PsExpression.make(PsSymbol(name, Bool())) for name in "abc"]
f3p5 = PsExpression.make(PsConstant(3.5, Fp(32))) f3p5 = PsExpression.make(PsConstant(3.5, Fp(32)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment