From 9ba8caa22f31473b97743bc7d39ac45f2cac1427 Mon Sep 17 00:00:00 2001
From: Behzad Safaei <iwia103h@a0522.nhr.fau.de>
Date: Sun, 4 May 2025 01:02:24 +0200
Subject: [PATCH] Fix serialized device kernels

---
 src/pairs/ir/loops.py                |  2 +-
 src/pairs/ir/math.py                 | 11 -----------
 src/pairs/transformations/devices.py |  9 +++++----
 3 files changed, 6 insertions(+), 16 deletions(-)

diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py
index 5544b7a..6108308 100644
--- a/src/pairs/ir/loops.py
+++ b/src/pairs/ir/loops.py
@@ -53,7 +53,7 @@ class For(ASTNode):
         super().__init__(sim)
         self.iterator = Iter(sim, self)
         self.min = Lit.cvt(sim, range_min)
-        self.max = Lit.cvt(sim, range_max)
+        self.max = ScalarOp.inline(Lit.cvt(sim, range_max))
         self.block = Block(sim, []) if block is None else block
         self.kernel = None
         self._kernel_candidate = False
diff --git a/src/pairs/ir/math.py b/src/pairs/ir/math.py
index d9bd573..9e7b85c 100644
--- a/src/pairs/ir/math.py
+++ b/src/pairs/ir/math.py
@@ -31,18 +31,7 @@ class MathFunction(ASTTerm):
         return "undefined"
 
     def inline_recursively(self):
-        method_name = "inline_recursively"
         self.inlined = True
-
-        if hasattr(self.cond, method_name) and callable(getattr(self.cond, method_name)):
-            self.cond.inline_recursively()
-
-        if hasattr(self.expr_if, method_name) and callable(getattr(self.expr_if, method_name)):
-            self.expr_if.inline_recursively()
-
-        if hasattr(self.expr_else, method_name) and callable(getattr(self.expr_else, method_name)):
-            self.expr_else.inline_recursively()
-
         return self
 
     def add_terminal(self, terminal):
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 4c4ec15..c4b8885 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -89,20 +89,21 @@ class AddDeviceKernels(Mutator):
         self._kernel_id = 0
         self._device_module = False
 
-    def create_kernel(self, sim, iterator, rmax, block):
+    def create_kernel(self, sim, for_node):
         kernel_name = f"{self._module_name}_kernel{self._kernel_id}"
         kernel = sim.find_kernel_by_name(kernel_name)
 
         if kernel is None:
-            kernel_body = Filter(sim, ScalarOp.inline(iterator < rmax.copy(True)), block)
-            kernel = Kernel(sim, kernel_name, kernel_body, iterator)
+            kernel_body = for_node if for_node.serial else Filter(sim, 
+                    ScalarOp.inline(for_node.iterator < for_node.max.copy(True)), for_node.block)
+            kernel = Kernel(sim, kernel_name, kernel_body, for_node.iterator)
             self._kernel_id += 1
 
         return kernel
     
     def mutate_For(self, ast_node):
         if ast_node.is_kernel_candidate() and self._device_module:
-            kernel = self.create_kernel(ast_node.sim, ast_node.iterator, ast_node.max, ast_node.block)
+            kernel = self.create_kernel(ast_node.sim, ast_node)
             ast_node = KernelLaunch(ast_node.sim, kernel, ast_node.iterator, ast_node.min, ast_node.max, ast_node.serial)
 
         else:
-- 
GitLab