Skip to content
Snippets Groups Projects
Commit 7644ab1f authored by Jan Hönig's avatar Jan Hönig
Browse files

Merge branch 'IntegerSquareRoot' into 'master'

Fixed integer square root

See merge request !274
parents 43393627 3bb88ad1
No related branches found
No related tags found
1 merge request!274Fixed integer square root
Pipeline #35674 failed
...@@ -444,7 +444,7 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -444,7 +444,7 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Pow(self, expr): def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication""" """Don't use std::pow function, for small integer exponents, write as multiplication"""
if not expr.free_symbols: if not expr.free_symbols:
return self._typed_number(expr.evalf(), get_type_of_expression(expr)) return self._typed_number(expr.evalf(), get_type_of_expression(expr.base))
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})" return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
......
...@@ -87,10 +87,25 @@ def test_sqrt_of_integer(): ...@@ -87,10 +87,25 @@ def test_sqrt_of_integer():
assignments = [ps.Assignment(tmp, sp.sqrt(3)), assignments = [ps.Assignment(tmp, sp.sqrt(3)),
ps.Assignment(f[0], tmp)] ps.Assignment(f[0], tmp)]
arr = np.array([1], dtype=np.float64) arr_double = np.array([1], dtype=np.float64)
kernel = ps.create_kernel(assignments).compile() kernel = ps.create_kernel(assignments).compile()
kernel(f=arr) kernel(f=arr_double)
assert 1.7 < arr[0] < 1.8 assert 1.7 < arr_double[0] < 1.8
f = ps.fields("f: float32[1D]")
tmp = sp.symbols("tmp")
assignments = [ps.Assignment(tmp, sp.sqrt(3)),
ps.Assignment(f[0], tmp)]
arr_single = np.array([1], dtype=np.float32)
config = ps.CreateKernelConfig(data_type="float32")
kernel = ps.create_kernel(assignments, config=config).compile()
kernel(f=arr_single)
code = ps.get_code_str(kernel.ast)
assert "1.7320508075688772f" in code
assert 1.7 < arr_single[0] < 1.8
def test_integer_comparision(): def test_integer_comparision():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment