diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 76ee03f5838b199b3977bceb3ab398d687b9e8a2..77c6c6ba06cbc39d2d3ad6a865d4c607d5eca200 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -5,6 +5,7 @@ import numpy as np import sympy as sp from sympy.core import S from sympy.printing.ccode import C89CodePrinter + from pystencils.astnodes import KernelFunction, Node from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import ( @@ -344,6 +345,9 @@ class CustomSympyPrinter(CCodePrinter): result = super(CustomSympyPrinter, self)._print_Piecewise(expr) return result.replace("\n", "") + def _print_Type(self, node): + return str(node) + def _print_Function(self, expr): infix_functions = { bitwise_xor: '^', @@ -356,7 +360,7 @@ class CustomSympyPrinter(CCodePrinter): return expr.to_c(self._print) if isinstance(expr, reinterpret_cast_func): arg, data_type = expr.args - return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) + return "*((%s)(& %s))" % (self._print(PointerType(data_type, restrict=False)), self._print(arg)) elif isinstance(expr, address_of): assert len(expr.args) == 1, "address_of must only have one argument" return "&(%s)" % self._print(expr.args[0]) diff --git a/pystencils/backends/opencl_backend.py b/pystencils/backends/opencl_backend.py index b5da806bb7d4993e4d10fc9615902171d455da1b..a21b3bf489e04d595654f4fb6a3799be20fdf891 100644 --- a/pystencils/backends/opencl_backend.py +++ b/pystencils/backends/opencl_backend.py @@ -68,6 +68,13 @@ class OpenClSympyPrinter(CudaSympyPrinter): CustomSympyPrinter.__init__(self) self.known_functions = OPENCL_KNOWN_FUNCTIONS + def _print_Type(self, node): + code = super()._print_Type(node) + if isinstance(node, pystencils.data_types.PointerType): + return "__global " + code + else: + return code + def _print_ThreadIndexingSymbol(self, node): symbol_name: str = node.name function_name, dimension = tuple(symbol_name.split("."))