From d1ee3d7f62e557655f8cd4265cbc227d68e645b0 Mon Sep 17 00:00:00 2001
From: Christoph Alt <christoph.alt@fau.de>
Date: Thu, 1 Dec 2022 17:55:19 +0100
Subject: [PATCH] dirty hack to print sqrt instead of pow( 0.5 )

---
 pystencils/backends/syclbackend.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pystencils/backends/syclbackend.py b/pystencils/backends/syclbackend.py
index 493b035ec..17d357fe9 100644
--- a/pystencils/backends/syclbackend.py
+++ b/pystencils/backends/syclbackend.py
@@ -6,7 +6,7 @@ import sympy as sp
 
 from pystencils.astnodes import Node, cast_func
 from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
-from pystencils.data_types import (Type, VectorType, get_type_of_expression, collate_types,)
+from pystencils.data_types import Type, VectorType, collate_types, get_type_of_expression
 from pystencils.field import Field
 
 with open(join(dirname(__file__), "sycl_known_functions.txt")) as f:
@@ -216,10 +216,10 @@ class SyCLSympyPrinter(CustomSympyPrinter):
             return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
         elif expr.exp.is_integer and -8 < expr.exp < 0:
             return super()._print_Pow(expr)
-        if get_type_of_expression(expr.exp) != base_type:
-            ret = pre_fixed_pow(sp.Pow(expr.base, cast_func(expr.exp, base_type)))
-        else:
+        if expr.exp == 0.5 or expr.exp == -0.5 or get_type_of_expression(expr.exp) == base_type:
             ret = pre_fixed_pow(expr)
+        else:
+            ret = pre_fixed_pow(sp.Pow(expr.base, cast_func(expr.exp, base_type)))
 
         try:
             number = float(ret)
-- 
GitLab