diff --git a/pystencils/backends/syclbackend.py b/pystencils/backends/syclbackend.py index 493b035ec12ad34872b9c38acd89792660696826..17d357fe941fd07eade0e264f8089574e95a1d3e 100644 --- a/pystencils/backends/syclbackend.py +++ b/pystencils/backends/syclbackend.py @@ -6,7 +6,7 @@ import sympy as sp from pystencils.astnodes import Node, cast_func from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c -from pystencils.data_types import (Type, VectorType, get_type_of_expression, collate_types,) +from pystencils.data_types import Type, VectorType, collate_types, get_type_of_expression from pystencils.field import Field with open(join(dirname(__file__), "sycl_known_functions.txt")) as f: @@ -216,10 +216,10 @@ class SyCLSympyPrinter(CustomSympyPrinter): return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})" elif expr.exp.is_integer and -8 < expr.exp < 0: return super()._print_Pow(expr) - if get_type_of_expression(expr.exp) != base_type: - ret = pre_fixed_pow(sp.Pow(expr.base, cast_func(expr.exp, base_type))) - else: + if expr.exp == 0.5 or expr.exp == -0.5 or get_type_of_expression(expr.exp) == base_type: ret = pre_fixed_pow(expr) + else: + ret = pre_fixed_pow(sp.Pow(expr.base, cast_func(expr.exp, base_type))) try: number = float(ret)