Skip to content
Snippets Groups Projects
Commit f875fbc0 authored by Martin Bauer's avatar Martin Bauer
Browse files

Simplify constant square roots

parent d5c24d0f
No related branches found
No related tags found
No related merge requests found
...@@ -299,6 +299,9 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -299,6 +299,9 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Pow(self, expr): def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication""" """Don't use std::pow function, for small integer exponents, write as multiplication"""
if not expr.free_symbols:
return self._typed_number(expr.evalf(), get_type_of_expression(expr))
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
......
...@@ -70,6 +70,9 @@ class cast_func(sp.Function): ...@@ -70,6 +70,9 @@ class cast_func(sp.Function):
def is_commutative(self): def is_commutative(self):
return self.args[0].is_commutative return self.args[0].is_commutative
def _eval_evalf(self, *args, **kwargs):
return self.args[0].evalf()
@property @property
def dtype(self): def dtype(self):
return self.args[1] return self.args[1]
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment