From 9ba8caa22f31473b97743bc7d39ac45f2cac1427 Mon Sep 17 00:00:00 2001 From: Behzad Safaei <iwia103h@a0522.nhr.fau.de> Date: Sun, 4 May 2025 01:02:24 +0200 Subject: [PATCH] Fix serialized device kernels --- src/pairs/ir/loops.py | 2 +- src/pairs/ir/math.py | 11 ----------- src/pairs/transformations/devices.py | 9 +++++---- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py index 5544b7a..6108308 100644 --- a/src/pairs/ir/loops.py +++ b/src/pairs/ir/loops.py @@ -53,7 +53,7 @@ class For(ASTNode): super().__init__(sim) self.iterator = Iter(sim, self) self.min = Lit.cvt(sim, range_min) - self.max = Lit.cvt(sim, range_max) + self.max = ScalarOp.inline(Lit.cvt(sim, range_max)) self.block = Block(sim, []) if block is None else block self.kernel = None self._kernel_candidate = False diff --git a/src/pairs/ir/math.py b/src/pairs/ir/math.py index d9bd573..9e7b85c 100644 --- a/src/pairs/ir/math.py +++ b/src/pairs/ir/math.py @@ -31,18 +31,7 @@ class MathFunction(ASTTerm): return "undefined" def inline_recursively(self): - method_name = "inline_recursively" self.inlined = True - - if hasattr(self.cond, method_name) and callable(getattr(self.cond, method_name)): - self.cond.inline_recursively() - - if hasattr(self.expr_if, method_name) and callable(getattr(self.expr_if, method_name)): - self.expr_if.inline_recursively() - - if hasattr(self.expr_else, method_name) and callable(getattr(self.expr_else, method_name)): - self.expr_else.inline_recursively() - return self def add_terminal(self, terminal): diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index 4c4ec15..c4b8885 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -89,20 +89,21 @@ class AddDeviceKernels(Mutator): self._kernel_id = 0 self._device_module = False - def create_kernel(self, sim, iterator, rmax, block): + def create_kernel(self, sim, for_node): kernel_name = f"{self._module_name}_kernel{self._kernel_id}" kernel = sim.find_kernel_by_name(kernel_name) if kernel is None: - kernel_body = Filter(sim, ScalarOp.inline(iterator < rmax.copy(True)), block) - kernel = Kernel(sim, kernel_name, kernel_body, iterator) + kernel_body = for_node if for_node.serial else Filter(sim, + ScalarOp.inline(for_node.iterator < for_node.max.copy(True)), for_node.block) + kernel = Kernel(sim, kernel_name, kernel_body, for_node.iterator) self._kernel_id += 1 return kernel def mutate_For(self, ast_node): if ast_node.is_kernel_candidate() and self._device_module: - kernel = self.create_kernel(ast_node.sim, ast_node.iterator, ast_node.max, ast_node.block) + kernel = self.create_kernel(ast_node.sim, ast_node) ast_node = KernelLaunch(ast_node.sim, kernel, ast_node.iterator, ast_node.min, ast_node.max, ast_node.serial) else: -- GitLab