Skip to content
Snippets Groups Projects
Commit 9ba8caa2 authored by Behzad Safaei's avatar Behzad Safaei
Browse files

Fix serialized device kernels

parent b3928430
No related branches found
No related tags found
No related merge requests found
......@@ -53,7 +53,7 @@ class For(ASTNode):
super().__init__(sim)
self.iterator = Iter(sim, self)
self.min = Lit.cvt(sim, range_min)
self.max = Lit.cvt(sim, range_max)
self.max = ScalarOp.inline(Lit.cvt(sim, range_max))
self.block = Block(sim, []) if block is None else block
self.kernel = None
self._kernel_candidate = False
......
......@@ -31,18 +31,7 @@ class MathFunction(ASTTerm):
return "undefined"
def inline_recursively(self):
method_name = "inline_recursively"
self.inlined = True
if hasattr(self.cond, method_name) and callable(getattr(self.cond, method_name)):
self.cond.inline_recursively()
if hasattr(self.expr_if, method_name) and callable(getattr(self.expr_if, method_name)):
self.expr_if.inline_recursively()
if hasattr(self.expr_else, method_name) and callable(getattr(self.expr_else, method_name)):
self.expr_else.inline_recursively()
return self
def add_terminal(self, terminal):
......
......@@ -89,20 +89,21 @@ class AddDeviceKernels(Mutator):
self._kernel_id = 0
self._device_module = False
def create_kernel(self, sim, iterator, rmax, block):
def create_kernel(self, sim, for_node):
kernel_name = f"{self._module_name}_kernel{self._kernel_id}"
kernel = sim.find_kernel_by_name(kernel_name)
if kernel is None:
kernel_body = Filter(sim, ScalarOp.inline(iterator < rmax.copy(True)), block)
kernel = Kernel(sim, kernel_name, kernel_body, iterator)
kernel_body = for_node if for_node.serial else Filter(sim,
ScalarOp.inline(for_node.iterator < for_node.max.copy(True)), for_node.block)
kernel = Kernel(sim, kernel_name, kernel_body, for_node.iterator)
self._kernel_id += 1
return kernel
def mutate_For(self, ast_node):
if ast_node.is_kernel_candidate() and self._device_module:
kernel = self.create_kernel(ast_node.sim, ast_node.iterator, ast_node.max, ast_node.block)
kernel = self.create_kernel(ast_node.sim, ast_node)
ast_node = KernelLaunch(ast_node.sim, kernel, ast_node.iterator, ast_node.min, ast_node.max, ast_node.serial)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment