diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 5e6adfa4ff73d0a813f8668d634a27a698180aea..67fd6b27ea9bad0b037affa63913193560ea6ed2 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -419,6 +419,10 @@ class PsCall(PsExpression): if not isinstance(other, PsCall): return False return super().structurally_equal(other) and self._function == other._function + + def __str__(self): + args = ", ".join(str(arg) for arg in self._args) + return f"PsCall({self._function}, ({args}))" class PsTernary(PsExpression): diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 80ea58d189dcb43d9bca62263bef6e42311de168..30b243d9cd614d9f843021dff52167297fccdbba 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -146,3 +146,15 @@ class PsMathFunction(PsFunction): @property def func(self) -> MathFunctions: return self._func + + def __str__(self) -> str: + return f"{self._func.function_name}" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsMathFunction): + return False + + return self._func == other._func + + def __hash__(self) -> int: + return hash(self._func) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 1e1355bbfda1827efb1b7b2134ff7d764a647919..3865db38fe603a6cf5fe4d31deef1743d4276bd6 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -460,7 +460,7 @@ class FreezeExpressions: while len(args) > 1: args = [ (PsCall(func, (args[i], args[i + 1])) if i + 1 < len(args) else args[i]) - for i in range((len(args) + 1) // 2) + for i in range(0, len(args), 2) ] return cast(PsCall, args[0]) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index b1e8525b1c6c8b1180a3ff345e2cbb38766fc9c9..b22df7d0bd132cc530e289b630f9c48851e4996b 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -25,9 +25,11 @@ from pystencils.backend.ast.expressions import ( PsLt, PsLe, PsGt, - PsGe + PsGe, + PsCall ) from pystencils.backend.constants import PsConstant +from pystencils.backend.functions import PsMathFunction, MathFunctions from pystencils.backend.kernelcreation import ( KernelCreationContext, FreezeExpressions, @@ -227,3 +229,33 @@ def test_freeze_piecewise(): piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q))) with pytest.raises(FreezeError): freeze(piecewise) + + +def test_multiarg_min_max(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + w, x, y, z = sp.symbols("w, x, y, z") + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + w2 = PsExpression.make(ctx.get_symbol("w")) + + def op(a, b): + return PsCall(PsMathFunction(MathFunctions.Min), (a, b)) + + expr = freeze(sp.Min(w, x, y)) + assert expr.structurally_equal(op(op(w2, x2), y2)) + + expr = freeze(sp.Min(w, x, y, z)) + assert expr.structurally_equal(op(op(w2, x2), op(y2, z2))) + + def op(a, b): + return PsCall(PsMathFunction(MathFunctions.Max), (a, b)) + + expr = freeze(sp.Max(w, x, y)) + assert expr.structurally_equal(op(op(w2, x2), y2)) + + expr = freeze(sp.Max(w, x, y, z)) + assert expr.structurally_equal(op(op(w2, x2), op(y2, z2)))