Skip to content
Snippets Groups Projects
Commit 69c29833 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add option to use `auto` in `SympyAssignment`s

parent 0b8e7e88
No related branches found
No related tags found
No related merge requests found
...@@ -518,12 +518,13 @@ class LoopOverCoordinate(Node): ...@@ -518,12 +518,13 @@ class LoopOverCoordinate(Node):
class SympyAssignment(Node): class SympyAssignment(Node):
def __init__(self, lhs_symbol, rhs_expr, is_const=True): def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=True):
super(SympyAssignment, self).__init__(parent=None) super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol self._lhs_symbol = lhs_symbol
self.rhs = sp.sympify(rhs_expr) self.rhs = sp.sympify(rhs_expr)
self._is_const = is_const self._is_const = is_const
self._is_declaration = self.__is_declaration() self._is_declaration = self.__is_declaration()
self.use_auto = use_auto
def __is_declaration(self): def __is_declaration(self):
if isinstance(self._lhs_symbol, cast_func): if isinstance(self._lhs_symbol, cast_func):
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import sympy as sp import sympy as sp
from sympy.core import S from sympy.core import S
from sympy.printing.ccode import C89CodePrinter from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import KernelFunction, Node from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import ( from pystencils.data_types import (
...@@ -229,11 +230,15 @@ class CBackend: ...@@ -229,11 +230,15 @@ class CBackend:
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
if node.is_declaration: if node.is_declaration:
if node.is_const: if node.use_auto:
prefix = 'const ' data_type = 'auto '
else: else:
prefix = '' if node.is_const:
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " " prefix = 'const '
else:
prefix = ''
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
return "%s%s = %s;" % (data_type, return "%s%s = %s;" % (data_type,
self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs)) self.sympy_printer.doprint(node.rhs))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment