Skip to content
Snippets Groups Projects
Commit d93c549c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fix typing of floor/ceil

parent e24cf411
1 merge request!406Fix typing of floor/ceil
Pipeline #67656 passed with stages
in 22 minutes and 27 seconds
...@@ -11,6 +11,7 @@ from sympy.core.relational import Relational ...@@ -11,6 +11,7 @@ from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction from sympy.functions.elementary.hyperbolic import HyperbolicFunction
from sympy.functions.elementary.integers import RoundFunction
from sympy.logic.boolalg import BooleanFunction from sympy.logic.boolalg import BooleanFunction
from sympy.logic.boolalg import BooleanAtom from sympy.logic.boolalg import BooleanAtom
...@@ -213,7 +214,7 @@ class TypeAdder: ...@@ -213,7 +214,7 @@ class TypeAdder:
new_args.append(a) new_args.append(a)
return expr.func(*new_args) if new_args else expr, collated_type return expr.func(*new_args) if new_args else expr, collated_type
elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction, elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
HyperbolicFunction, sp.log)): HyperbolicFunction, sp.log, RoundFunction)):
args_types = [self.figure_out_type(arg) for arg in expr.args] args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types]) collated_type = collate_types([t for _, t in args_types])
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
......
...@@ -39,7 +39,7 @@ def test_two_arguments(dtype, func, target): ...@@ -39,7 +39,7 @@ def test_two_arguments(dtype, func, target):
@pytest.mark.parametrize('dtype', ["float64", "float32"]) @pytest.mark.parametrize('dtype', ["float64", "float32"])
@pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan]) @pytest.mark.parametrize('func', [sp.sin, sp.cos, sp.sinh, sp.cosh, sp.atan, sp.floor, sp.ceiling])
@pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU]) @pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU])
def test_single_arguments(dtype, func, target): def test_single_arguments(dtype, func, target):
if target == ps.Target.GPU: if target == ps.Target.GPU:
...@@ -58,7 +58,8 @@ def test_single_arguments(dtype, func, target): ...@@ -58,7 +58,8 @@ def test_single_arguments(dtype, func, target):
ast = ps.create_kernel(up, config=config) ast = ps.create_kernel(up, config=config)
code = ps.get_code_str(ast) code = ps.get_code_str(ast)
if dtype == 'float32': if dtype == 'float32':
assert func.__name__.lower() in code func_name = func.__name__.lower() if func is not sp.ceiling else "ceil"
assert func_name in code
kernel = ast.compile() kernel = ast.compile()
dh.all_to_gpu() dh.all_to_gpu()
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment