Skip to content
Snippets Groups Projects
Commit 4106f9fd authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Parameterize test_reduction_assignments with reduction ops

parent 490ec914
No related branches found
No related tags found
1 merge request!438Reduction Support
......@@ -68,7 +68,13 @@ from pystencils.sympyextensions.integer_functions import (
ceil_to_multiple,
div_ceil,
)
from pystencils.sympyextensions.reduction import AddReductionAssignment
from pystencils.sympyextensions.reduction import (
AddReductionAssignment,
SubReductionAssignment,
MulReductionAssignment,
MinReductionAssignment,
MaxReductionAssignment,
)
from pystencils.types import PsTypeError
......@@ -498,10 +504,22 @@ def test_invalid_arrays():
_ = freeze(symb_arr)
def test_reduction_assignments():
@pytest.mark.parametrize("reduction_assignment_rhs_type",
[
(AddReductionAssignment, PsAdd),
(SubReductionAssignment, PsSub),
(MulReductionAssignment, PsMul),
(MinReductionAssignment, PsCall),
(MaxReductionAssignment, PsCall),
])
def test_reduction_assignments(
reduction_assignment_rhs_type
):
x = fields(f"x: float64[1d]")
w = TypedSymbol("w", "float64")
reduction_op, rhs_type = reduction_assignment_rhs_type
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
......@@ -512,7 +530,7 @@ def test_reduction_assignments():
)
ctx.set_iteration_space(ispace)
expr = freeze(AddReductionAssignment(w, 3 * x.center()))
expr = freeze(reduction_op(w, 3 * x.center()))
info = ctx.find_reduction_info(w.name)
......@@ -522,6 +540,8 @@ def test_reduction_assignments():
assert expr.lhs.symbol == info.local_symbol
assert expr.lhs.dtype == w.dtype
assert isinstance(expr.rhs, rhs_type)
def test_invalid_reduction_assignments():
x = fields(f"x: float64[1d]")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment