Skip to content
Snippets Groups Projects
Commit e607e5bd authored by Martin Bauer's avatar Martin Bauer
Browse files

Formatting and style

parent c391dbb6
No related branches found
No related tags found
No related merge requests found
......@@ -22,18 +22,11 @@ try:
except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
KERNCRAFT_NO_TERNARY_MODE = False
class UnsupportedCDialect(Exception):
def __init__(self):
super(UnsupportedCDialect, self).__init__()
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str:
"""Prints an abstract syntax tree node as C or CUDA code.
......@@ -63,7 +56,7 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom
from pystencils.backends.cuda_backend import CudaBackend
printer = CudaBackend(signature_only=signature_only)
else:
raise UnsupportedCDialect
raise ValueError("Unknown dialect: " + str(dialect))
code = printer(ast_node)
if not signature_only and isinstance(ast_node, KernelFunction):
code = "\n" + code
......
from os.path import dirname, join
from pystencils.astnodes import Node
from pystencils.backends.cbackend import (CBackend, CustomSympyPrinter,
generate_c)
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
fast_sqrt)
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
CUDA_KNOWN_FUNCTIONS = None
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
lines = f.readlines()
CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l}
......@@ -17,8 +13,8 @@ def generate_cuda(astnode: Node, signature_only: bool = False) -> str:
"""Prints an abstract syntax tree node as CUDA code.
Args:
ast_node:
signature_only:
astnode: KernelFunction node to generate code for
signature_only: if True only the signature is printed
Returns:
C-like code for the ast node and its descendants
......@@ -41,7 +37,8 @@ class CudaBackend(CBackend):
name=self.sympy_printer.doprint(node.symbol.name),
num_elements='*'.join([str(s) for s in node.shared_mem.shape]))
def _print_ThreadBlockSynchronization(self, node):
@staticmethod
def _print_ThreadBlockSynchronization(node):
code = "__synchtreads();"
return code
......
......@@ -60,8 +60,7 @@ from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.include import get_pystencils_include_path
from pystencils.utils import (atomic_file_write, file_handle_for_atomic_write,
recursive_dict_update)
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
def make_python_function(kernel_function_node, custom_backend=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment