Skip to content
Snippets Groups Projects
Commit d93c549c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fix typing of floor/ceil

parent e24cf411
Branches
Tags
1 merge request!406Fix typing of floor/ceil
Pipeline #67656 passed
...@@ -11,6 +11,7 @@ from sympy.core.relational import Relational ...@@ -11,6 +11,7 @@ from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction from sympy.functions.elementary.hyperbolic import HyperbolicFunction
from sympy.functions.elementary.integers import RoundFunction
from sympy.logic.boolalg import BooleanFunction from sympy.logic.boolalg import BooleanFunction
from sympy.logic.boolalg import BooleanAtom from sympy.logic.boolalg import BooleanAtom
...@@ -213,7 +214,7 @@ class TypeAdder: ...@@ -213,7 +214,7 @@ class TypeAdder:
new_args.append(a) new_args.append(a)
return expr.func(*new_args) if new_args else expr, collated_type return expr.func(*new_args) if new_args else expr, collated_type
elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction, elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
HyperbolicFunction, sp.log)): HyperbolicFunction, sp.log, RoundFunction)):
args_types = [self.figure_out_type(arg) for arg in expr.args] args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types]) collated_type = collate_types([t for _, t in args_types])
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
......
...@@ -39,7 +39,7 @@ def test_two_arguments(dtype, func, target): ...@@ -39,7 +39,7 @@ def test_two_arguments(dtype, func, target):
@pytest.mark.parametrize('dtype', ["float64", "float32"]) @pytest.mark.parametrize('dtype', ["float64", "float32"])
@pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan]) @pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan, sp.floor, sp.ceiling])
@pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU]) @pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU])
def test_single_arguments(dtype, func, target): def test_single_arguments(dtype, func, target):
if target == ps.Target.GPU: if target == ps.Target.GPU:
...@@ -58,7 +58,8 @@ def test_single_arguments(dtype, func, target): ...@@ -58,7 +58,8 @@ def test_single_arguments(dtype, func, target):
ast = ps.create_kernel(up, config=config) ast = ps.create_kernel(up, config=config)
code = ps.get_code_str(ast) code = ps.get_code_str(ast)
if dtype == 'float32': if dtype == 'float32':
assert func.__name__.lower() in code func_name = func.__name__.lower() if func is not sp.ceiling else "ceil"
assert func_name in code
kernel = ast.compile() kernel = ast.compile()
dh.all_to_gpu() dh.all_to_gpu()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment