From d93c549cc88c69c6b1cfa9cf06f4057654040f2e Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 17 Jul 2024 09:56:52 +0200 Subject: [PATCH] Fix typing of floor/ceil --- src/pystencils/typing/leaf_typing.py | 3 ++- tests/test_math_functions.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/pystencils/typing/leaf_typing.py b/src/pystencils/typing/leaf_typing.py index 9e7065b0a..c3ac54c59 100644 --- a/src/pystencils/typing/leaf_typing.py +++ b/src/pystencils/typing/leaf_typing.py @@ -11,6 +11,7 @@ from sympy.core.relational import Relational from sympy.functions.elementary.piecewise import ExprCondPair from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction from sympy.functions.elementary.hyperbolic import HyperbolicFunction +from sympy.functions.elementary.integers import RoundFunction from sympy.logic.boolalg import BooleanFunction from sympy.logic.boolalg import BooleanAtom @@ -213,7 +214,7 @@ class TypeAdder: new_args.append(a) return expr.func(*new_args) if new_args else expr, collated_type elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction, - HyperbolicFunction, sp.log)): + HyperbolicFunction, sp.log, RoundFunction)): args_types = [self.figure_out_type(arg) for arg in expr.args] collated_type = collate_types([t for _, t in args_types]) new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] diff --git a/tests/test_math_functions.py b/tests/test_math_functions.py index 1fd393788..7f76c1039 100644 --- a/tests/test_math_functions.py +++ b/tests/test_math_functions.py @@ -39,7 +39,7 @@ def test_two_arguments(dtype, func, target): @pytest.mark.parametrize('dtype', ["float64", "float32"]) -@pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan]) +@pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan, sp.floor, sp.ceiling]) @pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU]) def test_single_arguments(dtype, func, target): if target == ps.Target.GPU: @@ -58,7 +58,8 @@ def test_single_arguments(dtype, func, target): ast = ps.create_kernel(up, config=config) code = ps.get_code_str(ast) if dtype == 'float32': - assert func.__name__.lower() in code + func_name = func.__name__.lower() if func is not sp.ceiling else "ceil" + assert func_name in code kernel = ast.compile() dh.all_to_gpu() -- GitLab