diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py index 5544b7a530220f16a4f48d485f7c5b4382a8a359..61083089af55e756714b238d9bcef5a939785c15 100644 --- a/src/pairs/ir/loops.py +++ b/src/pairs/ir/loops.py @@ -53,7 +53,7 @@ class For(ASTNode): super().__init__(sim) self.iterator = Iter(sim, self) self.min = Lit.cvt(sim, range_min) - self.max = Lit.cvt(sim, range_max) + self.max = ScalarOp.inline(Lit.cvt(sim, range_max)) self.block = Block(sim, []) if block is None else block self.kernel = None self._kernel_candidate = False diff --git a/src/pairs/ir/math.py b/src/pairs/ir/math.py index d9bd5730230d1f209b5c598f8b3ff55254f36496..9e7b85c1f4f96f41ff29067717038076d8a31168 100644 --- a/src/pairs/ir/math.py +++ b/src/pairs/ir/math.py @@ -31,18 +31,7 @@ class MathFunction(ASTTerm): return "undefined" def inline_recursively(self): - method_name = "inline_recursively" self.inlined = True - - if hasattr(self.cond, method_name) and callable(getattr(self.cond, method_name)): - self.cond.inline_recursively() - - if hasattr(self.expr_if, method_name) and callable(getattr(self.expr_if, method_name)): - self.expr_if.inline_recursively() - - if hasattr(self.expr_else, method_name) and callable(getattr(self.expr_else, method_name)): - self.expr_else.inline_recursively() - return self def add_terminal(self, terminal): diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index 4c4ec154881fcbd71e57724036fdc62412c46fff..c4b8885c278d2d52a7c2e642a6ab0dcde704cf54 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -89,20 +89,21 @@ class AddDeviceKernels(Mutator): self._kernel_id = 0 self._device_module = False - def create_kernel(self, sim, iterator, rmax, block): + def create_kernel(self, sim, for_node): kernel_name = f"{self._module_name}_kernel{self._kernel_id}" kernel = sim.find_kernel_by_name(kernel_name) if kernel is None: - kernel_body = Filter(sim, ScalarOp.inline(iterator < rmax.copy(True)), block) - kernel = Kernel(sim, kernel_name, kernel_body, iterator) + kernel_body = for_node if for_node.serial else Filter(sim, + ScalarOp.inline(for_node.iterator < for_node.max.copy(True)), for_node.block) + kernel = Kernel(sim, kernel_name, kernel_body, for_node.iterator) self._kernel_id += 1 return kernel def mutate_For(self, ast_node): if ast_node.is_kernel_candidate() and self._device_module: - kernel = self.create_kernel(ast_node.sim, ast_node.iterator, ast_node.max, ast_node.block) + kernel = self.create_kernel(ast_node.sim, ast_node) ast_node = KernelLaunch(ast_node.sim, kernel, ast_node.iterator, ast_node.min, ast_node.max, ast_node.serial) else: