Skip to content
Snippets Groups Projects
Commit e607e5bd authored by Martin Bauer's avatar Martin Bauer
Browse files

Formatting and style

parent c391dbb6
No related branches found
No related tags found
No related merge requests found
...@@ -22,18 +22,11 @@ try: ...@@ -22,18 +22,11 @@ try:
except ImportError: except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
KERNCRAFT_NO_TERNARY_MODE = False KERNCRAFT_NO_TERNARY_MODE = False
class UnsupportedCDialect(Exception):
def __init__(self):
super(UnsupportedCDialect, self).__init__()
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str: def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str:
"""Prints an abstract syntax tree node as C or CUDA code. """Prints an abstract syntax tree node as C or CUDA code.
...@@ -63,7 +56,7 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom ...@@ -63,7 +56,7 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom
from pystencils.backends.cuda_backend import CudaBackend from pystencils.backends.cuda_backend import CudaBackend
printer = CudaBackend(signature_only=signature_only) printer = CudaBackend(signature_only=signature_only)
else: else:
raise UnsupportedCDialect raise ValueError("Unknown dialect: " + str(dialect))
code = printer(ast_node) code = printer(ast_node)
if not signature_only and isinstance(ast_node, KernelFunction): if not signature_only and isinstance(ast_node, KernelFunction):
code = "\n" + code code = "\n" + code
......
from os.path import dirname, join from os.path import dirname, join
from pystencils.astnodes import Node from pystencils.astnodes import Node
from pystencils.backends.cbackend import (CBackend, CustomSympyPrinter, from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
generate_c) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
fast_sqrt)
CUDA_KNOWN_FUNCTIONS = None
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f: with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
lines = f.readlines() lines = f.readlines()
CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l} CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l}
...@@ -17,8 +13,8 @@ def generate_cuda(astnode: Node, signature_only: bool = False) -> str: ...@@ -17,8 +13,8 @@ def generate_cuda(astnode: Node, signature_only: bool = False) -> str:
"""Prints an abstract syntax tree node as CUDA code. """Prints an abstract syntax tree node as CUDA code.
Args: Args:
ast_node: astnode: KernelFunction node to generate code for
signature_only: signature_only: if True only the signature is printed
Returns: Returns:
C-like code for the ast node and its descendants C-like code for the ast node and its descendants
...@@ -41,7 +37,8 @@ class CudaBackend(CBackend): ...@@ -41,7 +37,8 @@ class CudaBackend(CBackend):
name=self.sympy_printer.doprint(node.symbol.name), name=self.sympy_printer.doprint(node.symbol.name),
num_elements='*'.join([str(s) for s in node.shared_mem.shape])) num_elements='*'.join([str(s) for s in node.shared_mem.shape]))
def _print_ThreadBlockSynchronization(self, node): @staticmethod
def _print_ThreadBlockSynchronization(node):
code = "__synchtreads();" code = "__synchtreads();"
return code return code
......
...@@ -60,8 +60,7 @@ from appdirs import user_cache_dir, user_config_dir ...@@ -60,8 +60,7 @@ from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.include import get_pystencils_include_path from pystencils.include import get_pystencils_include_path
from pystencils.utils import (atomic_file_write, file_handle_for_atomic_write, from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
recursive_dict_update)
def make_python_function(kernel_function_node, custom_backend=None): def make_python_function(kernel_function_node, custom_backend=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment