Skip to content
Snippets Groups Projects
Commit 7644ab1f authored by Jan Hönig's avatar Jan Hönig
Browse files

Merge branch 'IntegerSquareRoot' into 'master'

Fixed integer square root

See merge request !274
parents 43393627 3bb88ad1
No related branches found
No related tags found
No related merge requests found
......@@ -444,7 +444,7 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
if not expr.free_symbols:
return self._typed_number(expr.evalf(), get_type_of_expression(expr))
return self._typed_number(expr.evalf(), get_type_of_expression(expr.base))
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
......
......@@ -87,10 +87,25 @@ def test_sqrt_of_integer():
assignments = [ps.Assignment(tmp, sp.sqrt(3)),
ps.Assignment(f[0], tmp)]
arr = np.array([1], dtype=np.float64)
arr_double = np.array([1], dtype=np.float64)
kernel = ps.create_kernel(assignments).compile()
kernel(f=arr)
assert 1.7 < arr[0] < 1.8
kernel(f=arr_double)
assert 1.7 < arr_double[0] < 1.8
f = ps.fields("f: float32[1D]")
tmp = sp.symbols("tmp")
assignments = [ps.Assignment(tmp, sp.sqrt(3)),
ps.Assignment(f[0], tmp)]
arr_single = np.array([1], dtype=np.float32)
config = ps.CreateKernelConfig(data_type="float32")
kernel = ps.create_kernel(assignments, config=config).compile()
kernel(f=arr_single)
code = ps.get_code_str(kernel.ast)
assert "1.7320508075688772f" in code
assert 1.7 < arr_single[0] < 1.8
def test_integer_comparision():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment