Skip to content
Snippets Groups Projects
Commit 959241b0 authored by Jan Hönig's avatar Jan Hönig
Browse files

Merge branch 'RoundOffError' into 'master'

Fix RoundOff problems

See merge request pycodegen/pystencils!282
parents 29e0e84e 19852424
No related branches found
No related tags found
No related merge requests found
...@@ -441,7 +441,7 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -441,7 +441,7 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Pow(self, expr): def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication""" """Don't use std::pow function, for small integer exponents, write as multiplication"""
if not expr.free_symbols: if not expr.free_symbols:
return self._typed_number(expr.evalf(), get_type_of_expression(expr.base)) return self._typed_number(expr.evalf(17), get_type_of_expression(expr.base))
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})" return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
...@@ -452,7 +452,7 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -452,7 +452,7 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Rational(self, expr): def _print_Rational(self, expr):
"""Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0""" """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
res = str(expr.evalf().num) res = str(expr.evalf(17))
return res return res
def _print_Equality(self, expr): def _print_Equality(self, expr):
......
...@@ -234,7 +234,7 @@ def apply_sympy_optimisations(assignments): ...@@ -234,7 +234,7 @@ def apply_sympy_optimisations(assignments):
# Evaluates all constant terms # Evaluates all constant terms
evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer, evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
lambda p: p.evalf()) lambda p: p.evalf(17))
sympy_optimisations = [evaluate_constant_terms] + list(optims_c99) sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
......
...@@ -101,8 +101,9 @@ def test_sqrt_of_integer(): ...@@ -101,8 +101,9 @@ def test_sqrt_of_integer():
kernel(f=arr_single) kernel(f=arr_single)
code = ps.get_code_str(kernel.ast) code = ps.get_code_str(kernel.ast)
# ps.show_code(kernel.ast)
assert "1.7320508075688772f" in code # 1.7320508075688772935 --> it is actually correct to round to ...773. This was wrong before !282
assert "1.7320508075688773f" in code
assert 1.7 < arr_single[0] < 1.8 assert 1.7 < arr_single[0] < 1.8
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment