From 6be4b5af798082acb82ab61623a0e3c2b8250c3b Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 3 Jul 2024 09:26:51 +0200 Subject: [PATCH] Fix multiarg min/max freeze, introduce test case, fix MathFunction equality --- src/pystencils/backend/ast/expressions.py | 4 +++ src/pystencils/backend/functions.py | 12 +++++++ .../backend/kernelcreation/freeze.py | 2 +- tests/nbackend/kernelcreation/test_freeze.py | 34 ++++++++++++++++++- 4 files changed, 50 insertions(+), 2 deletions(-) diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 5e6adfa4f..67fd6b27e 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -419,6 +419,10 @@ class PsCall(PsExpression): if not isinstance(other, PsCall): return False return super().structurally_equal(other) and self._function == other._function + + def __str__(self): + args = ", ".join(str(arg) for arg in self._args) + return f"PsCall({self._function}, ({args}))" class PsTernary(PsExpression): diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 80ea58d18..30b243d9c 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -146,3 +146,15 @@ class PsMathFunction(PsFunction): @property def func(self) -> MathFunctions: return self._func + + def __str__(self) -> str: + return f"{self._func.function_name}" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsMathFunction): + return False + + return self._func == other._func + + def __hash__(self) -> int: + return hash(self._func) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 1e1355bbf..3865db38f 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -460,7 +460,7 @@ class FreezeExpressions: while len(args) > 1: args = [ (PsCall(func, (args[i], args[i + 1])) if i + 1 < len(args) else args[i]) - for i in range((len(args) + 1) // 2) + for i in range(0, len(args), 2) ] return cast(PsCall, args[0]) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index b1e8525b1..b22df7d0b 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -25,9 +25,11 @@ from pystencils.backend.ast.expressions import ( PsLt, PsLe, PsGt, - PsGe + PsGe, + PsCall ) from pystencils.backend.constants import PsConstant +from pystencils.backend.functions import PsMathFunction, MathFunctions from pystencils.backend.kernelcreation import ( KernelCreationContext, FreezeExpressions, @@ -227,3 +229,33 @@ def test_freeze_piecewise(): piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q))) with pytest.raises(FreezeError): freeze(piecewise) + + +def test_multiarg_min_max(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + w, x, y, z = sp.symbols("w, x, y, z") + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + w2 = PsExpression.make(ctx.get_symbol("w")) + + def op(a, b): + return PsCall(PsMathFunction(MathFunctions.Min), (a, b)) + + expr = freeze(sp.Min(w, x, y)) + assert expr.structurally_equal(op(op(w2, x2), y2)) + + expr = freeze(sp.Min(w, x, y, z)) + assert expr.structurally_equal(op(op(w2, x2), op(y2, z2))) + + def op(a, b): + return PsCall(PsMathFunction(MathFunctions.Max), (a, b)) + + expr = freeze(sp.Max(w, x, y)) + assert expr.structurally_equal(op(op(w2, x2), y2)) + + expr = freeze(sp.Max(w, x, y, z)) + assert expr.structurally_equal(op(op(w2, x2), op(y2, z2))) -- GitLab