diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 2ed480c5566493b5dbe109093bb1ba79ae4ffe2f..495b5b976c39fb5b7c8f9efb860dfe1464fb0589 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -89,6 +89,11 @@ class Simulation:
     def kernels(self):
         return self.kernel_list
 
+    def find_kernel_by_name(self, name):
+        matches = [k for k in self.kernel_list if k.name == name]
+        assert len(matches) < 2, "find_kernel_by_name(): More than one match for kernel name!"
+        return matches[0] if len(matches) == 1 else None
+
     def ndims(self):
         return self.dims
 
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index a5b18f647033317709e36044c71226706c3f8964..b77a0d45b818d33dfd56f6cb192df5d09cfa34c8 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -72,12 +72,14 @@ class AddDeviceKernels(Mutator):
             for s in ast_node._block.stmts:
                 if s is not None:
                     if isinstance(s, For) and (not isinstance(s.min, Lit) or not isinstance(s.max, Lit)):
-                        if s.kernel is None:
-                            s.kernel = Kernel(ast_node.sim, f"{ast_node.name}_kernel{kernel_id}",
-                                              Filter(ast_node.sim, BinOp.inline(s.iterator < s.max), s.block), s.iterator)
+                        kernel_name = f"{ast_node.name}_kernel{kernel_id}"
+                        kernel = ast_node.sim.find_kernel_by_name(kernel_name)
+                        if kernel is None:
+                            kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max), s.block)
+                            kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator)
                             kernel_id += 1
 
-                        new_stmts.append(KernelLaunch(ast_node.sim, s.kernel, s.iterator, s.min, s.max))
+                        new_stmts.append(KernelLaunch(ast_node.sim, kernel, s.iterator, s.min, s.max))
 
                     else:
                         new_stmts.append(s)