From d93c549cc88c69c6b1cfa9cf06f4057654040f2e Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 17 Jul 2024 09:56:52 +0200
Subject: [PATCH] Fix typing of floor/ceil

---
 src/pystencils/typing/leaf_typing.py | 3 ++-
 tests/test_math_functions.py         | 5 +++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/pystencils/typing/leaf_typing.py b/src/pystencils/typing/leaf_typing.py
index 9e7065b0a..c3ac54c59 100644
--- a/src/pystencils/typing/leaf_typing.py
+++ b/src/pystencils/typing/leaf_typing.py
@@ -11,6 +11,7 @@ from sympy.core.relational import Relational
 from sympy.functions.elementary.piecewise import ExprCondPair
 from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
 from sympy.functions.elementary.hyperbolic import HyperbolicFunction
+from sympy.functions.elementary.integers import RoundFunction
 from sympy.logic.boolalg import BooleanFunction
 from sympy.logic.boolalg import BooleanAtom
 
@@ -213,7 +214,7 @@ class TypeAdder:
                     new_args.append(a)
             return expr.func(*new_args) if new_args else expr, collated_type
         elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
-                               HyperbolicFunction, sp.log)):
+                               HyperbolicFunction, sp.log, RoundFunction)):
             args_types = [self.figure_out_type(arg) for arg in expr.args]
             collated_type = collate_types([t for _, t in args_types])
             new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
diff --git a/tests/test_math_functions.py b/tests/test_math_functions.py
index 1fd393788..7f76c1039 100644
--- a/tests/test_math_functions.py
+++ b/tests/test_math_functions.py
@@ -39,7 +39,7 @@ def test_two_arguments(dtype, func, target):
 
 
 @pytest.mark.parametrize('dtype', ["float64", "float32"])
-@pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan])
+@pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan, sp.floor, sp.ceiling])
 @pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU])
 def test_single_arguments(dtype, func, target):
     if target == ps.Target.GPU:
@@ -58,7 +58,8 @@ def test_single_arguments(dtype, func, target):
     ast = ps.create_kernel(up, config=config)
     code = ps.get_code_str(ast)
     if dtype == 'float32':
-        assert func.__name__.lower() in code
+        func_name = func.__name__.lower() if func is not sp.ceiling else "ceil"
+        assert func_name in code
     kernel = ast.compile()
 
     dh.all_to_gpu()
-- 
GitLab