diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 987b68043702d43889f1daa25d39478ba4fe3432..fe31cf94cdfe5b07780baab166c812551f866701 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -68,7 +68,13 @@ from pystencils.sympyextensions.integer_functions import ( ceil_to_multiple, div_ceil, ) -from pystencils.sympyextensions.reduction import AddReductionAssignment +from pystencils.sympyextensions.reduction import ( + AddReductionAssignment, + SubReductionAssignment, + MulReductionAssignment, + MinReductionAssignment, + MaxReductionAssignment, +) from pystencils.types import PsTypeError @@ -498,10 +504,22 @@ def test_invalid_arrays(): _ = freeze(symb_arr) -def test_reduction_assignments(): +@pytest.mark.parametrize("reduction_assignment_rhs_type", + [ + (AddReductionAssignment, PsAdd), + (SubReductionAssignment, PsSub), + (MulReductionAssignment, PsMul), + (MinReductionAssignment, PsCall), + (MaxReductionAssignment, PsCall), + ]) +def test_reduction_assignments( + reduction_assignment_rhs_type +): x = fields(f"x: float64[1d]") w = TypedSymbol("w", "float64") + reduction_op, rhs_type = reduction_assignment_rhs_type + ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) @@ -512,7 +530,7 @@ def test_reduction_assignments(): ) ctx.set_iteration_space(ispace) - expr = freeze(AddReductionAssignment(w, 3 * x.center())) + expr = freeze(reduction_op(w, 3 * x.center())) info = ctx.find_reduction_info(w.name) @@ -522,6 +540,8 @@ def test_reduction_assignments(): assert expr.lhs.symbol == info.local_symbol assert expr.lhs.dtype == w.dtype + assert isinstance(expr.rhs, rhs_type) + def test_invalid_reduction_assignments(): x = fields(f"x: float64[1d]")