From 2e54c7a022fe6dd0d4f9984090c6c87c4aeae499 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 15 Nov 2024 15:37:49 +0100 Subject: [PATCH] Fix data type printing --- src/pystencilssfg/composer/basic_composer.py | 2 +- src/pystencilssfg/emission/printers.py | 8 ++++---- src/pystencilssfg/extensions/sycl.py | 4 ++-- src/pystencilssfg/ir/postprocessing.py | 10 +++++----- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 5dc1d7d..00d8b34 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -385,7 +385,7 @@ class SfgBasicComposer(SfgIComposer): args_str = ", ".join(str(arg) for arg in args) deps: set[SfgVar] = reduce(set.union, (depends(arg) for arg in args), set()) return SfgStatements( - f"{lhs_var.dtype} {lhs_var.name} {{ {args_str} }};", + f"{lhs_var.dtype.c_string()} {lhs_var.name} {{ {args_str} }};", (lhs_var,), deps, ) diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index 9337161..c562bf7 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -66,7 +66,7 @@ class SfgGeneralPrinter: def param_list(self, func: SfgFunction) -> str: params = sorted(list(func.parameters), key=lambda p: p.name) - return ", ".join(f"{param.dtype} {param.name}" for param in params) + return ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) class SfgHeaderPrinter(SfgGeneralPrinter): @@ -113,7 +113,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgFunction) def function(self, func: SfgFunction): params = sorted(list(func.parameters), key=lambda p: p.name) - param_list = ", ".join(f"{param.dtype} {param.name}" for param in params) + param_list = ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) return f"{func.return_type} {func.name} ( {param_list} );" @visit.case(SfgClass) @@ -149,7 +149,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgConstructor) def sfg_constructor(self, constr: SfgConstructor): code = f"{constr.owning_class.class_name} (" - code += ", ".join(f"{param.dtype} {param.name}" for param in constr.parameters) + code += ", ".join(f"{param.dtype.c_string()} {param.name}" for param in constr.parameters) code += ")\n" if constr.initializers: code += " : " + ", ".join(constr.initializers) + "\n" @@ -161,7 +161,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgMemberVariable) def sfg_member_var(self, var: SfgMemberVariable): - return f"{var.dtype} {var.name};" + return f"{var.dtype.c_string()} {var.name};" @visit.case(SfgMethod) def sfg_method(self, method: SfgMethod): diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index dc80202..3cb0c1c 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -131,7 +131,7 @@ class SyclGroup(AugExpr): comp.map_param( id_param, h_item, - f"{id_param.dtype} {id_param.name} = {h_item}.get_local_id();", + f"{id_param.dtype.c_string()} {id_param.name} = {h_item}.get_local_id();", ), SfgKernelCallNode(kernel), ) @@ -186,7 +186,7 @@ class SfgLambda: def get_code(self, ctx: SfgContext): captures = ", ".join(self._captures) - params = ", ".join(f"{p.dtype} {p.name}" for p in self._params) + params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params) body = self._tree.get_code(ctx) body = ctx.codestyle.indent(body) rtype = ( diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index e073356..638a55f 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -233,7 +233,7 @@ class SfgDeferredParamSetter(SfgDeferredNode): def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: live_var = ppc.get_live_variable(self._lhs.name) if live_var is not None: - code = f"{live_var.dtype} {live_var.name} = {self._rhs_expr};" + code = f"{live_var.dtype.c_string()} {live_var.name} = {self._rhs_expr};" return SfgStatements(code, (live_var,), tuple(self._depends)) else: return SfgSequence([]) @@ -291,7 +291,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): expr = self._extraction.ptr() nodes.append( SfgStatements( - f"{ptr.dtype} {ptr.name} {{ {expr} }};", (ptr,), expr.depends + f"{ptr.dtype.c_string()} {ptr.name} {{ {expr} }};", (ptr,), expr.depends ) ) @@ -313,7 +313,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): done.add(symb) expr = maybe_cast(expr, symb.dtype) return SfgStatements( - f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends + f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends ) else: return SfgStatements(f"/* {expr} == {symb} */", (), ()) @@ -330,7 +330,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): done.add(symb) expr = maybe_cast(expr, symb.dtype) return SfgStatements( - f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends + f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends ) else: return SfgStatements(f"/* {expr} == {symb} */", (), ()) @@ -355,7 +355,7 @@ class SfgDeferredVectorMapping(SfgDeferredNode): expr = self._vector.extract_component(idx) nodes.append( SfgStatements( - f"{param.dtype} {param.name} {{ {expr} }};", + f"{param.dtype.c_string()} {param.name} {{ {expr} }};", (param,), expr.depends, ) -- GitLab