diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py
index 5544b7a530220f16a4f48d485f7c5b4382a8a359..61083089af55e756714b238d9bcef5a939785c15 100644
--- a/src/pairs/ir/loops.py
+++ b/src/pairs/ir/loops.py
@@ -53,7 +53,7 @@ class For(ASTNode):
         super().__init__(sim)
         self.iterator = Iter(sim, self)
         self.min = Lit.cvt(sim, range_min)
-        self.max = Lit.cvt(sim, range_max)
+        self.max = ScalarOp.inline(Lit.cvt(sim, range_max))
         self.block = Block(sim, []) if block is None else block
         self.kernel = None
         self._kernel_candidate = False
diff --git a/src/pairs/ir/math.py b/src/pairs/ir/math.py
index d9bd5730230d1f209b5c598f8b3ff55254f36496..9e7b85c1f4f96f41ff29067717038076d8a31168 100644
--- a/src/pairs/ir/math.py
+++ b/src/pairs/ir/math.py
@@ -31,18 +31,7 @@ class MathFunction(ASTTerm):
         return "undefined"
 
     def inline_recursively(self):
-        method_name = "inline_recursively"
         self.inlined = True
-
-        if hasattr(self.cond, method_name) and callable(getattr(self.cond, method_name)):
-            self.cond.inline_recursively()
-
-        if hasattr(self.expr_if, method_name) and callable(getattr(self.expr_if, method_name)):
-            self.expr_if.inline_recursively()
-
-        if hasattr(self.expr_else, method_name) and callable(getattr(self.expr_else, method_name)):
-            self.expr_else.inline_recursively()
-
         return self
 
     def add_terminal(self, terminal):
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 4c4ec154881fcbd71e57724036fdc62412c46fff..c4b8885c278d2d52a7c2e642a6ab0dcde704cf54 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -89,20 +89,21 @@ class AddDeviceKernels(Mutator):
         self._kernel_id = 0
         self._device_module = False
 
-    def create_kernel(self, sim, iterator, rmax, block):
+    def create_kernel(self, sim, for_node):
         kernel_name = f"{self._module_name}_kernel{self._kernel_id}"
         kernel = sim.find_kernel_by_name(kernel_name)
 
         if kernel is None:
-            kernel_body = Filter(sim, ScalarOp.inline(iterator < rmax.copy(True)), block)
-            kernel = Kernel(sim, kernel_name, kernel_body, iterator)
+            kernel_body = for_node if for_node.serial else Filter(sim, 
+                    ScalarOp.inline(for_node.iterator < for_node.max.copy(True)), for_node.block)
+            kernel = Kernel(sim, kernel_name, kernel_body, for_node.iterator)
             self._kernel_id += 1
 
         return kernel
     
     def mutate_For(self, ast_node):
         if ast_node.is_kernel_candidate() and self._device_module:
-            kernel = self.create_kernel(ast_node.sim, ast_node.iterator, ast_node.max, ast_node.block)
+            kernel = self.create_kernel(ast_node.sim, ast_node)
             ast_node = KernelLaunch(ast_node.sim, kernel, ast_node.iterator, ast_node.min, ast_node.max, ast_node.serial)
 
         else: