Skip to content
Snippets Groups Projects

Basic support for OpenCL (experimental)

Merged Stephan Seitz requested to merge seitz/pystencils:opencl-backend into master
Files
10
@@ -58,6 +58,9 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom
elif dialect == 'cuda':
from pystencils.backends.cuda_backend import CudaBackend
printer = CudaBackend(signature_only=signature_only)
elif dialect == 'opencl':
from pystencils.backends.opencl_backend import OpenClBackend
printer = OpenClBackend(signature_only=signature_only)
else:
raise ValueError("Unknown dialect: " + str(dialect))
code = printer(ast_node)
@@ -165,14 +168,19 @@ class CBackend:
return result
def _print(self, node):
if isinstance(node, str):
return node
for cls in type(node).__mro__:
method_name = "_print_" + cls.__name__
if hasattr(self, method_name):
return getattr(self, method_name)(node)
raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + node.__class__.__name__)
def _print_Type(self, node):
return str(node)
def _print_KernelFunction(self, node):
function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
function_arguments = ["%s %s" % (self._print(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
launch_bounds = ""
if self._dialect == 'cuda':
max_threads = node.indexing.max_threads_per_block()
@@ -207,7 +215,11 @@ class CBackend:
def _print_SympyAssignment(self, node):
if node.is_declaration:
data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
if node.is_const:
prefix = 'const '
else:
prefix = ''
data_type = prefix + self._print(node.lhs.dtype) + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
else:
@@ -276,10 +288,10 @@ class CBackend:
]
destructuring_bindings.sort() # only for code aesthetics
return "{\n" + self._indent + \
("\n" + self._indent).join(destructuring_bindings) + \
"\n" + self._indent + \
("\n" + self._indent).join(self._print(node.body).splitlines()) + \
"\n}"
("\n" + self._indent).join(destructuring_bindings) + \
"\n" + self._indent + \
("\n" + self._indent).join(self._print(node.body).splitlines()) + \
"\n}"
# ------------------------------------------ Helper function & classes -------------------------------------------------
Loading