Skip to content
Snippets Groups Projects
Commit 3bb88ad1 authored by Markus Holzer's avatar Markus Holzer
Browse files

Fixed integer square root

parent 43393627
No related branches found
No related tags found
1 merge request!274Fixed integer square root
......@@ -444,7 +444,7 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
if not expr.free_symbols:
return self._typed_number(expr.evalf(), get_type_of_expression(expr))
return self._typed_number(expr.evalf(), get_type_of_expression(expr.base))
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
......
......@@ -87,10 +87,25 @@ def test_sqrt_of_integer():
assignments = [ps.Assignment(tmp, sp.sqrt(3)),
ps.Assignment(f[0], tmp)]
arr = np.array([1], dtype=np.float64)
arr_double = np.array([1], dtype=np.float64)
kernel = ps.create_kernel(assignments).compile()
kernel(f=arr)
assert 1.7 < arr[0] < 1.8
kernel(f=arr_double)
assert 1.7 < arr_double[0] < 1.8
f = ps.fields("f: float32[1D]")
tmp = sp.symbols("tmp")
assignments = [ps.Assignment(tmp, sp.sqrt(3)),
ps.Assignment(f[0], tmp)]
arr_single = np.array([1], dtype=np.float32)
config = ps.CreateKernelConfig(data_type="float32")
kernel = ps.create_kernel(assignments, config=config).compile()
kernel(f=arr_single)
code = ps.get_code_str(kernel.ast)
assert "1.7320508075688772f" in code
assert 1.7 < arr_single[0] < 1.8
def test_integer_comparision():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment