diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 5dc1d7df8260fece0875511ebd1f29b778dfc11f..00d8b3436c5d23a34ea5e3fb4db60e6d5d83fb5c 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -385,7 +385,7 @@ class SfgBasicComposer(SfgIComposer): args_str = ", ".join(str(arg) for arg in args) deps: set[SfgVar] = reduce(set.union, (depends(arg) for arg in args), set()) return SfgStatements( - f"{lhs_var.dtype} {lhs_var.name} {{ {args_str} }};", + f"{lhs_var.dtype.c_string()} {lhs_var.name} {{ {args_str} }};", (lhs_var,), deps, ) diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index 93371619e302932d97441ff069190ba956b0f04d..c562bf7bca5d59365c3a1cd5b9d77b8f7e7920f5 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -66,7 +66,7 @@ class SfgGeneralPrinter: def param_list(self, func: SfgFunction) -> str: params = sorted(list(func.parameters), key=lambda p: p.name) - return ", ".join(f"{param.dtype} {param.name}" for param in params) + return ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) class SfgHeaderPrinter(SfgGeneralPrinter): @@ -113,7 +113,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgFunction) def function(self, func: SfgFunction): params = sorted(list(func.parameters), key=lambda p: p.name) - param_list = ", ".join(f"{param.dtype} {param.name}" for param in params) + param_list = ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) return f"{func.return_type} {func.name} ( {param_list} );" @visit.case(SfgClass) @@ -149,7 +149,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgConstructor) def sfg_constructor(self, constr: SfgConstructor): code = f"{constr.owning_class.class_name} (" - code += ", ".join(f"{param.dtype} {param.name}" for param in constr.parameters) + code += ", ".join(f"{param.dtype.c_string()} {param.name}" for param in constr.parameters) code += ")\n" if constr.initializers: code += " : " + ", ".join(constr.initializers) + "\n" @@ -161,7 +161,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgMemberVariable) def sfg_member_var(self, var: SfgMemberVariable): - return f"{var.dtype} {var.name};" + return f"{var.dtype.c_string()} {var.name};" @visit.case(SfgMethod) def sfg_method(self, method: SfgMethod): diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index dc80202427036b8563f133666f71c5b4e60ec4ec..3cb0c1c5e50aa2b9557a176f3c541283641ad530 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -131,7 +131,7 @@ class SyclGroup(AugExpr): comp.map_param( id_param, h_item, - f"{id_param.dtype} {id_param.name} = {h_item}.get_local_id();", + f"{id_param.dtype.c_string()} {id_param.name} = {h_item}.get_local_id();", ), SfgKernelCallNode(kernel), ) @@ -186,7 +186,7 @@ class SfgLambda: def get_code(self, ctx: SfgContext): captures = ", ".join(self._captures) - params = ", ".join(f"{p.dtype} {p.name}" for p in self._params) + params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params) body = self._tree.get_code(ctx) body = ctx.codestyle.indent(body) rtype = ( diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index e0733569a3fa2342cecd114e30968f4c2944a84e..638a55f30f41f26f531a69a346b083dddd901797 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -233,7 +233,7 @@ class SfgDeferredParamSetter(SfgDeferredNode): def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: live_var = ppc.get_live_variable(self._lhs.name) if live_var is not None: - code = f"{live_var.dtype} {live_var.name} = {self._rhs_expr};" + code = f"{live_var.dtype.c_string()} {live_var.name} = {self._rhs_expr};" return SfgStatements(code, (live_var,), tuple(self._depends)) else: return SfgSequence([]) @@ -291,7 +291,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): expr = self._extraction.ptr() nodes.append( SfgStatements( - f"{ptr.dtype} {ptr.name} {{ {expr} }};", (ptr,), expr.depends + f"{ptr.dtype.c_string()} {ptr.name} {{ {expr} }};", (ptr,), expr.depends ) ) @@ -313,7 +313,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): done.add(symb) expr = maybe_cast(expr, symb.dtype) return SfgStatements( - f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends + f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends ) else: return SfgStatements(f"/* {expr} == {symb} */", (), ()) @@ -330,7 +330,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): done.add(symb) expr = maybe_cast(expr, symb.dtype) return SfgStatements( - f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends + f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends ) else: return SfgStatements(f"/* {expr} == {symb} */", (), ()) @@ -355,7 +355,7 @@ class SfgDeferredVectorMapping(SfgDeferredNode): expr = self._vector.extract_component(idx) nodes.append( SfgStatements( - f"{param.dtype} {param.name} {{ {expr} }};", + f"{param.dtype.c_string()} {param.name} {{ {expr} }};", (param,), expr.depends, )