diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index b081f752945133d56a4d32407e335b26d13614ec..6535b13fd791fdbe17fc6c82efa7c7dfc6657078 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -518,12 +518,13 @@ class LoopOverCoordinate(Node): class SympyAssignment(Node): - def __init__(self, lhs_symbol, rhs_expr, is_const=True): + def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=True): super(SympyAssignment, self).__init__(parent=None) self._lhs_symbol = lhs_symbol self.rhs = sp.sympify(rhs_expr) self._is_const = is_const self._is_declaration = self.__is_declaration() + self.use_auto = use_auto def __is_declaration(self): if isinstance(self._lhs_symbol, cast_func): diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 9a4095cc0a9f2a3d79c777e8f516c4f3dc7a6642..51fad62c29389c8c9dcfd34031c693b1b3bc43cf 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -5,6 +5,7 @@ import numpy as np import sympy as sp from sympy.core import S from sympy.printing.ccode import C89CodePrinter + from pystencils.astnodes import KernelFunction, Node from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import ( @@ -229,11 +230,15 @@ class CBackend: def _print_SympyAssignment(self, node): if node.is_declaration: - if node.is_const: - prefix = 'const ' + if node.use_auto: + data_type = 'auto ' else: - prefix = '' - data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " " + if node.is_const: + prefix = 'const ' + else: + prefix = '' + data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " " + return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))