Skip to content
Snippets Groups Projects

Add CudaBackend, CudaSympyPrinter

Merged Stephan Seitz requested to merge seitz/pystencils:eliminate-the-dialects into master
Files
9
@@ -32,7 +32,12 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy
KERNCRAFT_NO_TERNARY_MODE = False
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
class UnsupportedCDialect(Exception):
    """Raised by ``generate_c`` when no backend exists for the requested dialect.

    Args:
        dialect: the rejected dialect name, if known. It is optional so that
            the bare ``raise UnsupportedCDialect`` form keeps working.
    """

    def __init__(self, dialect=None):
        if dialect is None:
            message = "Unsupported C dialect"
        else:
            message = "Unsupported C dialect: %r" % (dialect,)
        super(UnsupportedCDialect, self).__init__(message)
        # Kept so that handlers can inspect the offending value programmatically.
        self.dialect = dialect
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str:
"""Prints an abstract syntax tree node as C or CUDA code.
This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
@@ -52,9 +57,16 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
ast_node.global_variables.update(d.symbols_defined)
else:
ast_node.global_variables = d.symbols_defined
printer = CBackend(signature_only=signature_only,
vector_instruction_set=ast_node.instruction_set,
dialect=dialect)
if custom_backend:
printer = custom_backend
elif dialect == 'c':
printer = CBackend(signature_only=signature_only,
vector_instruction_set=ast_node.instruction_set)
elif dialect == 'cuda':
from pystencils.backends.cuda_backend import CudaBackend
printer = CudaBackend(signature_only=signature_only)
else:
raise UnsupportedCDialect
code = printer(ast_node)
if not signature_only and isinstance(ast_node, KernelFunction):
code = "\n" + code
@@ -141,9 +153,9 @@ class CBackend:
def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
if sympy_printer is None:
if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect)
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
else:
self.sympy_printer = CustomSympyPrinter(dialect)
self.sympy_printer = CustomSympyPrinter()
else:
self.sympy_printer = sympy_printer
@@ -164,12 +176,12 @@ class CBackend:
method_name = "_print_" + cls.__name__
if hasattr(self, method_name):
return getattr(self, method_name)(node)
raise NotImplementedError("CBackend does not support node of type " + str(type(node)))
# `self.__class__` is a type object; adding it to a str raises TypeError, so use the class name instead.
raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + str(type(node)))
def _print_KernelFunction(self, node):
function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
launch_bounds = ""
if self._dialect == 'cuda':
if self.__class__ == 'cuda':
max_threads = node.indexing.max_threads_per_block()
if max_threads:
launch_bounds = "__launch_bounds__({}) ".format(max_threads)
@@ -241,10 +253,7 @@ class CBackend:
return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
def _print_SkipIteration(self, _):
if self._dialect == 'cuda':
return "return;"
else:
return "continue;"
return "continue;"
# Print a custom code node by delegating to the node's own get_code(), passing this
# backend's dialect and vector instruction set.
def _print_CustomCodeNode(self, node):
return node.get_code(self._dialect, self._vector_instruction_set)
@@ -292,10 +301,9 @@ class CBackend:
# noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter):
def __init__(self, dialect):
def __init__(self):
super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32")
self._dialect = dialect
if 'Min' in self.known_functions:
del self.known_functions['Min']
if 'Max' in self.known_functions:
@@ -347,22 +355,13 @@ class CustomSympyPrinter(CCodePrinter):
else:
return "((%s)(%s))" % (data_type, self._print(arg))
elif isinstance(expr, fast_division):
if self._dialect == "cuda":
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(expr.args[0] / expr.args[1]))
return "({})".format(self._print(expr.args[0] / expr.args[1]))
elif isinstance(expr, fast_sqrt):
if self._dialect == "cuda":
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(sp.sqrt(expr.args[0])))
return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
return self._print(expr.args[0])
elif isinstance(expr, fast_inv_sqrt):
if self._dialect == "cuda":
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
elif expr.func in infix_functions:
return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
elif expr.func == int_power_of_2:
@@ -392,8 +391,8 @@ class CustomSympyPrinter(CCodePrinter):
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
def __init__(self, instruction_set, dialect):
super(VectorizedCustomSympyPrinter, self).__init__(dialect=dialect)
def __init__(self, instruction_set):
super(VectorizedCustomSympyPrinter, self).__init__()
self.instruction_set = instruction_set
def _scalarFallback(self, func_name, expr, *args, **kwargs):
Loading