Skip to content
Snippets Groups Projects
Commit c5cdc763 authored by Markus Holzer's avatar Markus Holzer
Browse files

Fix: vectorization of float sqrt

parent e685321c
No related merge requests found
...@@ -391,6 +391,8 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -391,6 +391,8 @@ class CustomSympyPrinter(CCodePrinter):
if isinstance(arg, sp.Number) and arg.is_finite: if isinstance(arg, sp.Number) and arg.is_finite:
return self._typed_number(arg, data_type) return self._typed_number(arg, data_type)
else: else:
if str(arg) == "-1":
print("!!")
return "((%s)(%s))" % (data_type, self._print(arg)) return "((%s)(%s))" % (data_type, self._print(arg))
elif isinstance(expr, fast_division): elif isinstance(expr, fast_division):
return "({})".format(self._print(expr.args[0] / expr.args[1])) return "({})".format(self._print(expr.args[0] / expr.args[1]))
......
...@@ -550,7 +550,7 @@ def get_type_of_expression(expr, ...@@ -550,7 +550,7 @@ def get_type_of_expression(expr,
return result return result
elif isinstance(expr, sp.Pow): elif isinstance(expr, sp.Pow):
base_type = get_type(expr.args[0]) base_type = get_type(expr.args[0])
if expr.exp.is_integer: if expr.exp.is_integer or expr.exp == sp.Rational(1, 2):
return base_type return base_type
else: else:
return collate_types([create_type(default_float_type), base_type]) return collate_types([create_type(default_float_type), base_type])
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment