From 2e54c7a022fe6dd0d4f9984090c6c87c4aeae499 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 15 Nov 2024 15:37:49 +0100
Subject: [PATCH] Fix data type printing

---
 src/pystencilssfg/composer/basic_composer.py |  2 +-
 src/pystencilssfg/emission/printers.py       |  8 ++++----
 src/pystencilssfg/extensions/sycl.py         |  4 ++--
 src/pystencilssfg/ir/postprocessing.py       | 10 +++++-----
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py
index 5dc1d7d..00d8b34 100644
--- a/src/pystencilssfg/composer/basic_composer.py
+++ b/src/pystencilssfg/composer/basic_composer.py
@@ -385,7 +385,7 @@ class SfgBasicComposer(SfgIComposer):
             args_str = ", ".join(str(arg) for arg in args)
             deps: set[SfgVar] = reduce(set.union, (depends(arg) for arg in args), set())
             return SfgStatements(
-                f"{lhs_var.dtype} {lhs_var.name} {{ {args_str} }};",
+                f"{lhs_var.dtype.c_string()} {lhs_var.name} {{ {args_str} }};",
                 (lhs_var,),
                 deps,
             )
diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py
index 9337161..c562bf7 100644
--- a/src/pystencilssfg/emission/printers.py
+++ b/src/pystencilssfg/emission/printers.py
@@ -66,7 +66,7 @@ class SfgGeneralPrinter:
 
     def param_list(self, func: SfgFunction) -> str:
         params = sorted(list(func.parameters), key=lambda p: p.name)
-        return ", ".join(f"{param.dtype} {param.name}" for param in params)
+        return ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params)
 
 
 class SfgHeaderPrinter(SfgGeneralPrinter):
@@ -113,7 +113,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
     @visit.case(SfgFunction)
     def function(self, func: SfgFunction):
         params = sorted(list(func.parameters), key=lambda p: p.name)
-        param_list = ", ".join(f"{param.dtype} {param.name}" for param in params)
+        param_list = ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params)
         return f"{func.return_type} {func.name} ( {param_list} );"
 
     @visit.case(SfgClass)
@@ -149,7 +149,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
     @visit.case(SfgConstructor)
     def sfg_constructor(self, constr: SfgConstructor):
         code = f"{constr.owning_class.class_name} ("
-        code += ", ".join(f"{param.dtype} {param.name}" for param in constr.parameters)
+        code += ", ".join(f"{param.dtype.c_string()} {param.name}" for param in constr.parameters)
         code += ")\n"
         if constr.initializers:
             code += "  : " + ", ".join(constr.initializers) + "\n"
@@ -161,7 +161,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
 
     @visit.case(SfgMemberVariable)
     def sfg_member_var(self, var: SfgMemberVariable):
-        return f"{var.dtype} {var.name};"
+        return f"{var.dtype.c_string()} {var.name};"
 
     @visit.case(SfgMethod)
     def sfg_method(self, method: SfgMethod):
diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py
index dc80202..3cb0c1c 100644
--- a/src/pystencilssfg/extensions/sycl.py
+++ b/src/pystencilssfg/extensions/sycl.py
@@ -131,7 +131,7 @@ class SyclGroup(AugExpr):
             comp.map_param(
                 id_param,
                 h_item,
-                f"{id_param.dtype} {id_param.name} = {h_item}.get_local_id();",
+                f"{id_param.dtype.c_string()} {id_param.name} = {h_item}.get_local_id();",
             ),
             SfgKernelCallNode(kernel),
         )
@@ -186,7 +186,7 @@ class SfgLambda:
 
     def get_code(self, ctx: SfgContext):
         captures = ", ".join(self._captures)
-        params = ", ".join(f"{p.dtype} {p.name}" for p in self._params)
+        params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params)
         body = self._tree.get_code(ctx)
         body = ctx.codestyle.indent(body)
         rtype = (
diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py
index e073356..638a55f 100644
--- a/src/pystencilssfg/ir/postprocessing.py
+++ b/src/pystencilssfg/ir/postprocessing.py
@@ -233,7 +233,7 @@ class SfgDeferredParamSetter(SfgDeferredNode):
     def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
         live_var = ppc.get_live_variable(self._lhs.name)
         if live_var is not None:
-            code = f"{live_var.dtype} {live_var.name} = {self._rhs_expr};"
+            code = f"{live_var.dtype.c_string()} {live_var.name} = {self._rhs_expr};"
             return SfgStatements(code, (live_var,), tuple(self._depends))
         else:
             return SfgSequence([])
@@ -291,7 +291,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
             expr = self._extraction.ptr()
             nodes.append(
                 SfgStatements(
-                    f"{ptr.dtype} {ptr.name} {{ {expr} }};", (ptr,), expr.depends
+                    f"{ptr.dtype.c_string()} {ptr.name} {{ {expr} }};", (ptr,), expr.depends
                 )
             )
 
@@ -313,7 +313,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
                 done.add(symb)
                 expr = maybe_cast(expr, symb.dtype)
                 return SfgStatements(
-                    f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends
+                    f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends
                 )
             else:
                 return SfgStatements(f"/* {expr} == {symb} */", (), ())
@@ -330,7 +330,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
                 done.add(symb)
                 expr = maybe_cast(expr, symb.dtype)
                 return SfgStatements(
-                    f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends
+                    f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends
                 )
             else:
                 return SfgStatements(f"/* {expr} == {symb} */", (), ())
@@ -355,7 +355,7 @@ class SfgDeferredVectorMapping(SfgDeferredNode):
                 expr = self._vector.extract_component(idx)
                 nodes.append(
                     SfgStatements(
-                        f"{param.dtype} {param.name} {{ {expr} }};",
+                        f"{param.dtype.c_string()} {param.name} {{ {expr} }};",
                         (param,),
                         expr.depends,
                     )
-- 
GitLab