diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index a39a0a994618735ffbf04bdb74bdf16f29e37731..987b68043702d43889f1daa25d39478ba4fe3432 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -8,6 +8,7 @@ from pystencils import ( create_numeric_type, TypedSymbol, DynamicType, + KernelConstraintsError, ) from pystencils.sympyextensions import tcast from pystencils.sympyextensions.pointers import mem_acc @@ -68,6 +69,7 @@ from pystencils.sympyextensions.integer_functions import ( div_ceil, ) from pystencils.sympyextensions.reduction import AddReductionAssignment +from pystencils.types import PsTypeError def test_freeze_simple(): @@ -525,30 +527,34 @@ def test_invalid_reduction_assignments(): x = fields(f"x: float64[1d]") w = TypedSymbol("w", "float64") - ctx = KernelCreationContext() - freeze = FreezeExpressions(ctx) - - one = PsExpression.make(PsConstant(1, ctx.index_dtype)) - counter = ctx.get_symbol("ctr", ctx.index_dtype) - ispace = FullIterationSpace( - ctx, [FullIterationSpace.Dimension(one, one, one, counter)] - ) - ctx.set_iteration_space(ispace) - - invalid_assignment = Assignment(w, -1 * x.center()) + assignment = Assignment(w, -1 * x.center()) reduction_assignment = AddReductionAssignment(w, 3 * x.center()) - # reduction symbol is used before ReductionAssignment - with pytest.raises(FreezeError): - _ = [freeze(asm) for asm in [invalid_assignment, reduction_assignment]] + expected_errors_for_invalid_cases = [ + # 1) Reduction symbol is used before ReductionAssignment. + # May only be used for reductions -> KernelConstraintsError + ([assignment, reduction_assignment], KernelConstraintsError), + # 2) Reduction symbol is used after ReductionAssignment. + # Reduction symbol is converted to pointer after freeze -> PsTypeError + ([reduction_assignment, assignment], PsTypeError), + # 3) Duplicate ReductionAssignment + # May only be used once for now -> KernelConstraintsError + ([reduction_assignment, reduction_assignment], KernelConstraintsError) + ] - # reduction symbol is used after ReductionAssignment - with pytest.raises(FreezeError): - _ = [freeze(asm) for asm in [reduction_assignment, invalid_assignment]] + for invalid_assignment, error_class in expected_errors_for_invalid_cases: + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) - # duplicate ReductionAssignment - with pytest.raises(FreezeError): - _ = [freeze(asm) for asm in [reduction_assignment, reduction_assignment]] + one = PsExpression.make(PsConstant(1, ctx.index_dtype)) + counter = ctx.get_symbol("ctr", ctx.index_dtype) + ispace = FullIterationSpace( + ctx, [FullIterationSpace.Dimension(one, one, one, counter)] + ) + ctx.set_iteration_space(ispace) + + with pytest.raises(error_class): + _ = [freeze(asm) for asm in invalid_assignment] def test_memory_access():