Skip to content
Snippets Groups Projects

Auto for assignments

Merged Stephan Seitz requested to merge seitz/pystencils:auto-for-assignments into master
Files
3
@@ -5,6 +5,7 @@ import numpy as np
@@ -5,6 +5,7 @@ import numpy as np
import sympy as sp
import sympy as sp
from sympy.core import S
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter
from sympy.printing.ccode import C89CodePrinter
 
from pystencils.astnodes import KernelFunction, Node
from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
from pystencils.data_types import (
@@ -32,9 +33,9 @@ def generate_c(ast_node: Node,
@@ -32,9 +33,9 @@ def generate_c(ast_node: Node,
with_globals=True) -> str:
with_globals=True) -> str:
"""Prints an abstract syntax tree node as C or CUDA code.
"""Prints an abstract syntax tree node as C or CUDA code.
This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
This function does not need to distinguish for most AST nodes between C, C++ or CUDA code, it just prints 'C-like'
in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
code as encoded in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different
functions.
create_kernel functions.
Args:
Args:
ast_node:
ast_node:
@@ -229,11 +230,15 @@ class CBackend:
@@ -229,11 +230,15 @@ class CBackend:
def _print_SympyAssignment(self, node):
def _print_SympyAssignment(self, node):
if node.is_declaration:
if node.is_declaration:
if node.is_const:
if node.use_auto:
prefix = 'const '
data_type = 'auto '
else:
else:
prefix = ''
if node.is_const:
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
prefix = 'const '
 
else:
 
prefix = ''
 
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
 
return "%s%s = %s;" % (data_type,
return "%s%s = %s;" % (data_type,
self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
self.sympy_printer.doprint(node.rhs))
Loading