diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 559295930911c390f90cddafadb4d0c8da2b4b66..9efca1f53a6870d2b0ce2bdd1da547d16beb6a6c 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -353,6 +353,9 @@ class CustomSympyPrinter(CCodePrinter): result = super(CustomSympyPrinter, self)._print_Piecewise(expr) return result.replace("\n", "") + def _print_Type(self, node): + return str(node) + def _print_Function(self, expr): infix_functions = { bitwise_xor: '^', @@ -365,7 +368,7 @@ class CustomSympyPrinter(CCodePrinter): return expr.to_c(self._print) if isinstance(expr, reinterpret_cast_func): arg, data_type = expr.args - return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) + return "*((%s)(& %s))" % (self._print(PointerType(data_type, restrict=False)), self._print(arg)) elif isinstance(expr, address_of): assert len(expr.args) == 1, "address_of must only have one argument" return "&(%s)" % self._print(expr.args[0]) diff --git a/pystencils/backends/opencl_backend.py b/pystencils/backends/opencl_backend.py index 60ea06be9723005654155b256f35e094db0005f6..3ab7f820ea30adebf584977625b2e559f897ca27 100644 --- a/pystencils/backends/opencl_backend.py +++ b/pystencils/backends/opencl_backend.py @@ -68,6 +68,13 @@ class OpenClSympyPrinter(CudaSympyPrinter): CustomSympyPrinter.__init__(self) self.known_functions = OPENCL_KNOWN_FUNCTIONS + def _print_Type(self, node): + code = super()._print_Type(node) + if isinstance(node, pystencils.data_types.PointerType): + return "__global " + code + else: + return code + def _print_ThreadIndexingSymbol(self, node): symbol_name: str = node.name function_name, dimension = tuple(symbol_name.split("."))