Skip to content
Snippets Groups Projects
Commit 6be4b5af authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fix multiarg min/max freeze, introduce test case, fix MathFunction equality

parent cbf8ada5
No related branches found
No related tags found
1 merge request!394Extend symbolic language support
...@@ -419,6 +419,10 @@ class PsCall(PsExpression): ...@@ -419,6 +419,10 @@ class PsCall(PsExpression):
if not isinstance(other, PsCall): if not isinstance(other, PsCall):
return False return False
return super().structurally_equal(other) and self._function == other._function return super().structurally_equal(other) and self._function == other._function
def __str__(self):
args = ", ".join(str(arg) for arg in self._args)
return f"PsCall({self._function}, ({args}))"
class PsTernary(PsExpression): class PsTernary(PsExpression):
......
...@@ -146,3 +146,15 @@ class PsMathFunction(PsFunction): ...@@ -146,3 +146,15 @@ class PsMathFunction(PsFunction):
@property @property
def func(self) -> MathFunctions: def func(self) -> MathFunctions:
return self._func return self._func
def __str__(self) -> str:
return f"{self._func.function_name}"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsMathFunction):
return False
return self._func == other._func
def __hash__(self) -> int:
return hash(self._func)
...@@ -460,7 +460,7 @@ class FreezeExpressions: ...@@ -460,7 +460,7 @@ class FreezeExpressions:
while len(args) > 1: while len(args) > 1:
args = [ args = [
(PsCall(func, (args[i], args[i + 1])) if i + 1 < len(args) else args[i]) (PsCall(func, (args[i], args[i + 1])) if i + 1 < len(args) else args[i])
for i in range((len(args) + 1) // 2) for i in range(0, len(args), 2)
] ]
return cast(PsCall, args[0]) return cast(PsCall, args[0])
......
...@@ -25,9 +25,11 @@ from pystencils.backend.ast.expressions import ( ...@@ -25,9 +25,11 @@ from pystencils.backend.ast.expressions import (
PsLt, PsLt,
PsLe, PsLe,
PsGt, PsGt,
PsGe PsGe,
PsCall
) )
from pystencils.backend.constants import PsConstant from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import PsMathFunction, MathFunctions
from pystencils.backend.kernelcreation import ( from pystencils.backend.kernelcreation import (
KernelCreationContext, KernelCreationContext,
FreezeExpressions, FreezeExpressions,
...@@ -227,3 +229,33 @@ def test_freeze_piecewise(): ...@@ -227,3 +229,33 @@ def test_freeze_piecewise():
piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q))) piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q)))
with pytest.raises(FreezeError): with pytest.raises(FreezeError):
freeze(piecewise) freeze(piecewise)
def test_multiarg_min_max():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
w, x, y, z = sp.symbols("w, x, y, z")
x2 = PsExpression.make(ctx.get_symbol("x"))
y2 = PsExpression.make(ctx.get_symbol("y"))
z2 = PsExpression.make(ctx.get_symbol("z"))
w2 = PsExpression.make(ctx.get_symbol("w"))
def op(a, b):
return PsCall(PsMathFunction(MathFunctions.Min), (a, b))
expr = freeze(sp.Min(w, x, y))
assert expr.structurally_equal(op(op(w2, x2), y2))
expr = freeze(sp.Min(w, x, y, z))
assert expr.structurally_equal(op(op(w2, x2), op(y2, z2)))
def op(a, b):
return PsCall(PsMathFunction(MathFunctions.Max), (a, b))
expr = freeze(sp.Max(w, x, y))
assert expr.structurally_equal(op(op(w2, x2), y2))
expr = freeze(sp.Max(w, x, y, z))
assert expr.structurally_equal(op(op(w2, x2), op(y2, z2)))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment