diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index a7f39967dde062baed774597c71a7b32044fb4f5..f6827b830af5011be5c4e88566a2bd756c3b0b35 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -144,6 +144,7 @@ class CGen:
             kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
 
         self.print(f"__global__ void {kernel.name}({kernel_params}) {{")
+        self.print(f"    const int {kernel.iterator.name()} = blockIdx.x * blockDim.x + threadIdx.x;")
         self.generate_statement(kernel.block)
         self.print("}")
 
@@ -270,19 +271,19 @@ class CGen:
         if isinstance(ast_node, KernelLaunch):
             kernel = ast_node.kernel
             kernel_params = ""
-            for var in module.read_only_variables():
+            for var in kernel.read_only_variables():
                 decl = var.name()
                 kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
 
-            for var in module.write_variables():
+            for var in kernel.write_variables():
                 decl = f"&{var.name()}"
                 kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
 
-            for array in module.arrays():
+            for array in kernel.arrays():
                 decl = f"d_{array.name()}"
                 kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
 
-            for prop in module.properties():
+            for prop in kernel.properties():
                 decl = f"d_{prop.name()}"
                 kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
 
@@ -290,10 +291,9 @@ class CGen:
                 decl = self.generate_expression(bin_op)
                 kernel_params += decl if len(kernel_params) <= 0 else f", {decl}"
 
-            elems = ast_node.kernel.max - ast_node.kernel.min
-            threads_per_block = self.generate_expression(ast_node.kernel.threads_per_block)
-            blocks = self.generate_expression((elems + threads_per_block - 1) // threads_per_block)
-            self.print(f"{kernel.name}<<<{blocks}, {threads_per_block}>>>({kernel_params});")
+            elems = ast_node.max - ast_node.min
+            blocks = self.generate_expression((elems + ast_node.threads_per_block - 1) / ast_node.threads_per_block)
+            self.print(f"{kernel.name}<<<{blocks}, {ast_node.threads_per_block}>>>({kernel_params});")
 
         if isinstance(ast_node, ModuleCall):
             module = ast_node.module
diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py
index b1cdf0f7eacb5f89ea2a08889d90a50caa5873b5..334f707bf757968f7006e861bc069597ff5d5aea 100644
--- a/src/pairs/ir/kernel.py
+++ b/src/pairs/ir/kernel.py
@@ -1,5 +1,6 @@
 from pairs.ir.arrays import Array
 from pairs.ir.ast_node import ASTNode
+from pairs.ir.bin_op import BinOp
 from pairs.ir.properties import Property
 from pairs.ir.variables import Var
 
@@ -7,7 +8,7 @@ from pairs.ir.variables import Var
 class Kernel(ASTNode):
     last_kernel = 0
 
-    def __init__(self, sim, name=None, block=None):
+    def __init__(self, sim, name=None, block=None, iterator=None):
         super().__init__(sim)
         self._id = Kernel.last_kernel
         self._name = name if name is not None else "kernel" + str(Kernel.last_kernel)
@@ -16,6 +17,7 @@ class Kernel(ASTNode):
         self._properties = {}
         self._bin_ops = []
         self._block = block
+        self._iterator = iterator
         sim.add_kernel(self)
         Kernel.last_kernel += 1
 
@@ -31,6 +33,10 @@ class Kernel(ASTNode):
     def block(self):
         return self._block
 
+    @property
+    def iterator(self):
+        return self._iterator
+
     def variables(self):
         return self._variables
 
diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py
index 08cf0e01571d3c84b7cee475d7e4390da5266149..8615eb2d46cb98bce1404ba584c74d7ab3d51ac1 100644
--- a/src/pairs/transformations/__init__.py
+++ b/src/pairs/transformations/__init__.py
@@ -61,8 +61,8 @@ class Transformations:
 
     def add_device_kernels(self):
         if self._target.is_gpu():
-            self._analysis.fetch_kernel_references()
             self._add_device_kernels.mutate()
+            self._analysis.fetch_kernel_references()
 
     def apply_all(self):
         self.lower_everything()
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index fec459a24739bcdf52eb3a560a4f8410d6c4afc8..ba4ac4021523ff96733f4a2addf21ffa1e621f75 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -80,12 +80,12 @@ class AddDeviceKernels(Mutator):
             for s in ast_node._block.stmts:
                 if s is not None:
                     if isinstance(s, For) and (not isinstance(s.min, Lit) or not isinstance(s.max, Lit)):
-                        kernel = Kernel(ast_node.sim, f"{ast_node.name}_kernel{kernel_id}", s.block)
+                        kernel = Kernel(ast_node.sim, f"{ast_node.name}_kernel{kernel_id}", s.block, s.iterator)
                         new_stmts.append(KernelLaunch(ast_node.sim, kernel, s.iterator, s.min, s.max))
                         kernel_id += 1
                     else:
                         new_stmts.append(s)
 
-            ast_node._block_stmts = new_stmts
+            ast_node._block.stmts = new_stmts
 
         return ast_node