diff --git a/src/pystencils/typing/leaf_typing.py b/src/pystencils/typing/leaf_typing.py index 9e7065b0a5960febbcc954f70040b9097031041e..c3ac54c59ac9fc15271b7522484acd2eb432428e 100644 --- a/src/pystencils/typing/leaf_typing.py +++ b/src/pystencils/typing/leaf_typing.py @@ -11,6 +11,7 @@ from sympy.core.relational import Relational from sympy.functions.elementary.piecewise import ExprCondPair from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction from sympy.functions.elementary.hyperbolic import HyperbolicFunction +from sympy.functions.elementary.integers import RoundFunction from sympy.logic.boolalg import BooleanFunction from sympy.logic.boolalg import BooleanAtom @@ -213,7 +214,7 @@ class TypeAdder: new_args.append(a) return expr.func(*new_args) if new_args else expr, collated_type elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction, - HyperbolicFunction, sp.log)): + HyperbolicFunction, sp.log, RoundFunction)): args_types = [self.figure_out_type(arg) for arg in expr.args] collated_type = collate_types([t for _, t in args_types]) new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] diff --git a/tests/test_math_functions.py b/tests/test_math_functions.py index 1fd39378847e6ec391c0e09ccfd56fac2a4e2e95..7f76c103918013e4deff0cda04f52d7b6e9e1019 100644 --- a/tests/test_math_functions.py +++ b/tests/test_math_functions.py @@ -39,7 +39,7 @@ def test_two_arguments(dtype, func, target): @pytest.mark.parametrize('dtype', ["float64", "float32"]) -@pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan]) +@pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan, sp.floor, sp.ceiling]) @pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU]) def test_single_arguments(dtype, func, target): if target == ps.Target.GPU: @@ -58,7 +58,8 @@ def test_single_arguments(dtype, func, target): ast = ps.create_kernel(up, config=config) code = ps.get_code_str(ast) if dtype == 'float32': - assert func.__name__.lower() in code + func_name = func.__name__.lower() if func is not sp.ceiling else "ceil" + assert func_name in code kernel = ast.compile() dh.all_to_gpu()