diff --git a/tests/kernelcreation/test_type_cast.py b/tests/kernelcreation/test_type_cast.py index a88adbf78e0b16f56fbd05f70316004f7c4356dc..8ad6d867042ff6e57e5baee8c12ff45bae17e8e4 100644 --- a/tests/kernelcreation/test_type_cast.py +++ b/tests/kernelcreation/test_type_cast.py @@ -61,8 +61,13 @@ def test_type_cast(gen_config, xp, from_type, to_type): kfunc = kernel.compile() kfunc(inp=inp, outp=outp) - # rounding mode depends on platform - try: + if np.issubdtype(from_type, np.floating) and not np.issubdtype( + to_type, np.floating + ): + # rounding mode depends on platform + try: + xp.testing.assert_array_equal(outp, truncated) + except AssertionError: + xp.testing.assert_array_equal(outp, rounded) + else: xp.testing.assert_array_equal(outp, truncated) - except AssertionError: - xp.testing.assert_array_equal(outp, rounded)