diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 6367807f9fb8c0632053504196d3486a0b6a3174..5be413f32a9661db2ecb21d02e4dff95630b9848 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -353,6 +353,9 @@ class CustomSympyPrinter(CCodePrinter): result = super(CustomSympyPrinter, self)._print_Piecewise(expr) return result.replace("\n", "") + def _print_Type(self, node): + return str(node) + def _print_Function(self, expr): infix_functions = { bitwise_xor: '^', @@ -365,7 +368,7 @@ class CustomSympyPrinter(CCodePrinter): return expr.to_c(self._print) if isinstance(expr, reinterpret_cast_func): arg, data_type = expr.args - return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) + return "*((%s)(& %s))" % (self._print(PointerType(data_type, restrict=False)), self._print(arg)) elif isinstance(expr, address_of): assert len(expr.args) == 1, "address_of must only have one argument" return "&(%s)" % self._print(expr.args[0]) diff --git a/pystencils/backends/opencl_backend.py b/pystencils/backends/opencl_backend.py index b5da806bb7d4993e4d10fc9615902171d455da1b..a21b3bf489e04d595654f4fb6a3799be20fdf891 100644 --- a/pystencils/backends/opencl_backend.py +++ b/pystencils/backends/opencl_backend.py @@ -68,6 +68,13 @@ class OpenClSympyPrinter(CudaSympyPrinter): CustomSympyPrinter.__init__(self) self.known_functions = OPENCL_KNOWN_FUNCTIONS + def _print_Type(self, node): + code = super()._print_Type(node) + if isinstance(node, pystencils.data_types.PointerType): + return "__global " + code + else: + return code + def _print_ThreadIndexingSymbol(self, node): symbol_name: str = node.name function_name, dimension = tuple(symbol_name.split("."))