Skip to content
Snippets Groups Projects

Add CudaBackend, CudaSympyPrinter

Merged Stephan Seitz requested to merge seitz/pystencils:eliminate-the-dialects into master
Files
5
@@ -32,6 +32,11 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy
@@ -32,6 +32,11 @@ __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSy
KERNCRAFT_NO_TERNARY_MODE = False
KERNCRAFT_NO_TERNARY_MODE = False
 
class UnsupportedCDialect(Exception):
    """Error signalling that a requested C dialect has no matching backend.

    Args:
        dialect: The rejected dialect name, if the raiser knows it. It is
            included in the error message and stored on the instance.
            Defaults to ``None`` so that a bare ``raise UnsupportedCDialect``
            keeps working.
    """

    def __init__(self, dialect=None):
        # Build a readable message instead of an empty exception so the
        # failure is understandable from the traceback alone.
        if dialect is None:
            message = "Unsupported C dialect"
        else:
            message = "Unsupported C dialect: %r" % (dialect,)
        super(UnsupportedCDialect, self).__init__(message)
        self.dialect = dialect
 
 
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
"""Prints an abstract syntax tree node as C or CUDA code.
"""Prints an abstract syntax tree node as C or CUDA code.
@@ -52,9 +57,15 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
@@ -52,9 +57,15 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str
ast_node.global_variables.update(d.symbols_defined)
ast_node.global_variables.update(d.symbols_defined)
else:
else:
ast_node.global_variables = d.symbols_defined
ast_node.global_variables = d.symbols_defined
printer = CBackend(signature_only=signature_only,
vector_instruction_set=ast_node.instruction_set,
if dialect == 'c':
dialect=dialect)
printer = CBackend(signature_only=signature_only,
 
vector_instruction_set=ast_node.instruction_set)
 
elif dialect == 'cuda':
 
from pystencils.backends.cuda_backend import CudaBackend
 
printer = CudaBackend(signature_only=signature_only)
 
else:
 
raise UnsupportedCDialect
code = printer(ast_node)
code = printer(ast_node)
if not signature_only and isinstance(ast_node, KernelFunction):
if not signature_only and isinstance(ast_node, KernelFunction):
code = "\n" + code
code = "\n" + code
@@ -141,9 +152,9 @@ class CBackend:
@@ -141,9 +152,9 @@ class CBackend:
def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
if sympy_printer is None:
if sympy_printer is None:
if vector_instruction_set is not None:
if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect)
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
else:
else:
self.sympy_printer = CustomSympyPrinter(dialect)
self.sympy_printer = CustomSympyPrinter()
else:
else:
self.sympy_printer = sympy_printer
self.sympy_printer = sympy_printer
@@ -164,12 +175,12 @@ class CBackend:
@@ -164,12 +175,12 @@ class CBackend:
method_name = "_print_" + cls.__name__
method_name = "_print_" + cls.__name__
if hasattr(self, method_name):
if hasattr(self, method_name):
return getattr(self, method_name)(node)
return getattr(self, method_name)(node)
raise NotImplementedError("CBackend does not support node of type " + str(type(node)))
raise NotImplementedError(self.__class__ + " does not support node of type " + str(type(node)))
def _print_KernelFunction(self, node):
def _print_KernelFunction(self, node):
function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
launch_bounds = ""
launch_bounds = ""
if self._dialect == 'cuda':
if self.__class__ == 'cuda':
max_threads = node.indexing.max_threads_per_block()
max_threads = node.indexing.max_threads_per_block()
if max_threads:
if max_threads:
launch_bounds = "__launch_bounds__({}) ".format(max_threads)
launch_bounds = "__launch_bounds__({}) ".format(max_threads)
@@ -241,10 +252,7 @@ class CBackend:
@@ -241,10 +252,7 @@ class CBackend:
return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
def _print_SkipIteration(self, _):
def _print_SkipIteration(self, _):
if self._dialect == 'cuda':
return "continue;"
return "return;"
else:
return "continue;"
def _print_CustomCodeNode(self, node):
def _print_CustomCodeNode(self, node):
return node.get_code(self._dialect, self._vector_instruction_set)
return node.get_code(self._dialect, self._vector_instruction_set)
@@ -292,10 +300,9 @@ class CBackend:
@@ -292,10 +300,9 @@ class CBackend:
# noinspection PyPep8Naming
# noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter):
class CustomSympyPrinter(CCodePrinter):
def __init__(self, dialect):
def __init__(self):
super(CustomSympyPrinter, self).__init__()
super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32")
self._float_type = create_type("float32")
self._dialect = dialect
if 'Min' in self.known_functions:
if 'Min' in self.known_functions:
del self.known_functions['Min']
del self.known_functions['Min']
if 'Max' in self.known_functions:
if 'Max' in self.known_functions:
@@ -347,22 +354,13 @@ class CustomSympyPrinter(CCodePrinter):
@@ -347,22 +354,13 @@ class CustomSympyPrinter(CCodePrinter):
else:
else:
return "((%s)(%s))" % (data_type, self._print(arg))
return "((%s)(%s))" % (data_type, self._print(arg))
elif isinstance(expr, fast_division):
elif isinstance(expr, fast_division):
if self._dialect == "cuda":
return "({})".format(self._print(expr.args[0] / expr.args[1]))
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(expr.args[0] / expr.args[1]))
elif isinstance(expr, fast_sqrt):
elif isinstance(expr, fast_sqrt):
if self._dialect == "cuda":
return "({})".format(self._print(sp.sqrt(expr.args[0])))
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(sp.sqrt(expr.args[0])))
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
return self._print(expr.args[0])
return self._print(expr.args[0])
elif isinstance(expr, fast_inv_sqrt):
elif isinstance(expr, fast_inv_sqrt):
if self._dialect == "cuda":
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
else:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
elif expr.func in infix_functions:
elif expr.func in infix_functions:
return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
elif expr.func == int_power_of_2:
elif expr.func == int_power_of_2:
@@ -392,8 +390,8 @@ class CustomSympyPrinter(CCodePrinter):
@@ -392,8 +390,8 @@ class CustomSympyPrinter(CCodePrinter):
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
def __init__(self, instruction_set, dialect):
def __init__(self, instruction_set):
super(VectorizedCustomSympyPrinter, self).__init__(dialect=dialect)
super(VectorizedCustomSympyPrinter, self).__init__()
self.instruction_set = instruction_set
self.instruction_set = instruction_set
def _scalarFallback(self, func_name, expr, *args, **kwargs):
def _scalarFallback(self, func_name, expr, *args, **kwargs):
Loading