diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py
index 996a99ef93b4f499b9555af266573b999f4a1f56..d4552eba7183ab9a5371cb69e6bc6e07b22ed592 100644
--- a/src/pairs/analysis/devices.py
+++ b/src/pairs/analysis/devices.py
@@ -1,4 +1,5 @@
 from pairs.ir.arrays import ArrayAccess
+from pairs.ir.branches import Branch
 from pairs.ir.lit import Lit
 from pairs.ir.loops import For
 from pairs.ir.quaternions import QuaternionOp
@@ -13,10 +14,25 @@ class MarkCandidateLoops(Visitor):
         super().__init__(ast)
 
     def visit_Module(self, ast_node):
-        for s in ast_node._block.stmts:
-            if s is not None:
-                if isinstance(s, For) and (not isinstance(s.min, Lit) or not isinstance(s.max, Lit)):
-                    s.mark_as_kernel_candidate()
+        possible_candidates = []
+        for stmt in ast_node._block.stmts:
+            if stmt is not None:
+                if isinstance(stmt, Branch):
+                    for branch_stmt in stmt.block_if.stmts:
+                        if isinstance(branch_stmt, For):
+                            possible_candidates.append(branch_stmt)
+
+                    if stmt.block_else is not None:
+                        for branch_stmt in stmt.block_else.stmts:
+                            if isinstance(branch_stmt, For):
+                                possible_candidates.append(branch_stmt)
+
+                if isinstance(stmt, For):
+                    possible_candidates.append(stmt)
+
+        for stmt in possible_candidates:
+            if not isinstance(stmt.min, Lit) or not isinstance(stmt.max, Lit):
+                stmt.mark_as_kernel_candidate()
 
         self.visit_children(ast_node)
 
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 9e53b95b439d7f4a4bc6e5239aacd3530d324dbc..e050f9ed52c82b253970ec35c17097b0f710fc93 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -1,7 +1,7 @@
 import math
 from pairs.ir.actions import Actions
 from pairs.ir.block import Block
-from pairs.ir.branches import Filter
+from pairs.ir.branches import Branch, Filter
 from pairs.ir.cast import Cast
 from pairs.ir.contexts import Contexts
 from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef
@@ -86,33 +86,72 @@ class AddDeviceCopies(Mutator):
 class AddDeviceKernels(Mutator):
     def __init__(self, ast=None):
         super().__init__(ast)
+        self._module_name = None
+        self._kernel_id = 0
 
-    def mutate_Module(self, ast_node):
-        ast_node._block = self.mutate(ast_node._block)
+    def create_kernel(self, sim, iterator, rmax, block):
+        kernel_name = f"{self._module_name}_kernel{self._kernel_id}"
+        kernel = sim.find_kernel_by_name(kernel_name)
+
+        if kernel is None:
+            kernel_body = Filter(sim, ScalarOp.inline(iterator < rmax.copy(True)), block)
+            kernel = Kernel(sim, kernel_name, kernel_body, iterator)
+            self._kernel_id += 1
 
+        return kernel
+
+    def mutate_Module(self, ast_node):
         if ast_node.run_on_device:
+            self._module_name = ast_node.name
+            self._kernel_id = 0
+
             new_stmts = []
-            kernel_id = 0
-            for s in ast_node._block.stmts:
-                if s is not None:
-                    if isinstance(s, For) and s.is_kernel_candidate():
-                        kernel_name = f"{ast_node.name}_kernel{kernel_id}"
-                        kernel = ast_node.sim.find_kernel_by_name(kernel_name)
-                        if kernel is None:
-                            kernel_body = Filter(ast_node.sim,
-                                                 ScalarOp.inline(s.iterator < s.max.copy(True)),
-                                                 s.block)
-
-                            kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator)
-                            kernel_id += 1
-
-                        new_stmts.append(KernelLaunch(ast_node.sim, kernel, s.iterator, s.min, s.max))
+            for stmt in ast_node._block.stmts:
+                if stmt is not None:
+                    if isinstance(stmt, For) and stmt.is_kernel_candidate():
+                        kernel = self.create_kernel(ast_node.sim, stmt.iterator, stmt.max, stmt.block)
+                        new_stmts.append(
+                            KernelLaunch(ast_node.sim, kernel, stmt.iterator, stmt.min, stmt.max))
 
                     else:
-                        new_stmts.append(s)
+                        if isinstance(stmt, Branch):
+                            stmt = self.check_and_mutate_branch(stmt)
+
+                        new_stmts.append(stmt)
 
             ast_node._block.stmts = new_stmts
 
+        ast_node._block = self.mutate(ast_node._block)
+        return ast_node
+
+    def check_and_mutate_branch(self, ast_node):
+        new_stmts = []
+        for stmt in ast_node.block_if.stmts:
+            if stmt is not None:
+                if isinstance(stmt, For) and stmt.is_kernel_candidate():
+                    kernel = self.create_kernel(ast_node.sim, stmt.iterator, stmt.max, stmt.block)
+                    new_stmts.append(
+                        KernelLaunch(ast_node.sim, kernel, stmt.iterator, stmt.min, stmt.max))
+
+                else:
+                    new_stmts.append(stmt)
+
+        ast_node.block_if.stmts = new_stmts
+
+        if ast_node.block_else is not None:
+            new_stmts = []
+            for stmt in ast_node.block_else.stmts:
+                if stmt is not None:
+                    if isinstance(stmt, For) and stmt.is_kernel_candidate():
+                        kernel = self.create_kernel(ast_node.sim, stmt.iterator, stmt.max, stmt.block)
+                        new_stmts.append(
+                            KernelLaunch(ast_node.sim, kernel, stmt.iterator, stmt.min, stmt.max))
+
+                    else:
+                        new_stmts.append(stmt)
+
+            ast_node.block_else.stmts = new_stmts
+
         return ast_node