From 4106f9fdf0a4f1c0c6806ac3bca2a6ae467d31d1 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Fri, 25 Apr 2025 15:57:33 +0200 Subject: [PATCH] Parameterize test_reduction_assignments with reduction ops --- tests/nbackend/kernelcreation/test_freeze.py | 26 +++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 987b68043..fe31cf94c 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -68,7 +68,13 @@ from pystencils.sympyextensions.integer_functions import ( ceil_to_multiple, div_ceil, ) -from pystencils.sympyextensions.reduction import AddReductionAssignment +from pystencils.sympyextensions.reduction import ( + AddReductionAssignment, + SubReductionAssignment, + MulReductionAssignment, + MinReductionAssignment, + MaxReductionAssignment, +) from pystencils.types import PsTypeError @@ -498,10 +504,22 @@ def test_invalid_arrays(): _ = freeze(symb_arr) -def test_reduction_assignments(): +@pytest.mark.parametrize("reduction_assignment_rhs_type", + [ + (AddReductionAssignment, PsAdd), + (SubReductionAssignment, PsSub), + (MulReductionAssignment, PsMul), + (MinReductionAssignment, PsCall), + (MaxReductionAssignment, PsCall), + ]) +def test_reduction_assignments( + reduction_assignment_rhs_type +): x = fields(f"x: float64[1d]") w = TypedSymbol("w", "float64") + reduction_op, rhs_type = reduction_assignment_rhs_type + ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) @@ -512,7 +530,7 @@ def test_reduction_assignments(): ) ctx.set_iteration_space(ispace) - expr = freeze(AddReductionAssignment(w, 3 * x.center())) + expr = freeze(reduction_op(w, 3 * x.center())) info = ctx.find_reduction_info(w.name) @@ -522,6 +540,8 @@ def test_reduction_assignments(): assert expr.lhs.symbol == info.local_symbol assert expr.lhs.dtype == w.dtype + assert isinstance(expr.rhs, rhs_type) + def test_invalid_reduction_assignments(): x = fields(f"x: float64[1d]") -- GitLab