diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py
index 61734d1a914022471b473ef61afe4c735ead08bd..814cfcbcdaf4466018e3cf491d72b6588a5346b2 100644
--- a/src/pairs/ir/arrays.py
+++ b/src/pairs/ir/arrays.py
@@ -133,11 +133,11 @@ class ArrayAccess(ASTTerm):
         ArrayAccess.last_acc += 1
         return ArrayAccess.last_acc - 1
 
-    def __init__(self, sim, array, index):
+    def __init__(self, sim, array, indexes):
         super().__init__(sim)
         self.acc_id = ArrayAccess.new_id()
         self.array = array
-        self.partial_indexes = [Lit.cvt(sim, index)]
+        self.partial_indexes = indexes if isinstance(indexes, list) else [Lit.cvt(sim, indexes)]
         self.flat_index = None
         self.inlined = False
         self.terminals = set()
@@ -156,6 +156,9 @@ class ArrayAccess(ASTTerm):
         self.inlined = True
         return self
 
+    def copy(self):
+        return ArrayAccess(self.sim, self.array, self.partial_indexes)
+
     def check_and_set_flat_index(self):
         if len(self.partial_indexes) == self.array.ndims():
             sizes = self.array.sizes()
diff --git a/src/pairs/ir/bin_op.py b/src/pairs/ir/bin_op.py
index 779e9f5669af236b03ef82d866ed6417164437c0..f098e0b6c4cf9d86d98d79578c50e30ecedc1aae 100644
--- a/src/pairs/ir/bin_op.py
+++ b/src/pairs/ir/bin_op.py
@@ -54,6 +54,9 @@ class BinOp(VectorExpression):
         b = self.rhs.id() if isinstance(self.rhs, BinOp) else self.rhs
         return f"BinOp<{a} {self.op} {b}>"
 
+    def copy(self):
+        return BinOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem)
+
     def match(self, bin_op):
         return self.lhs == bin_op.lhs and self.rhs == bin_op.rhs and self.op == bin_op.operator()
 
diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py
index cefa2a2c918b51e68acb2acd30c2b954a5e1a6c2..6ba75d860ba09286a97ff9c35bfa52a87f9734bd 100644
--- a/src/pairs/ir/device.py
+++ b/src/pairs/ir/device.py
@@ -9,7 +9,6 @@ class HostRef(ASTNode):
     def __init__(self, sim, elem):
         super().__init__(sim)
         self.elem = elem
-        self.sim.add_statement(self)
 
     def type(self):
         return self.elem.type()
diff --git a/src/pairs/ir/lit.py b/src/pairs/ir/lit.py
index ce829b3f4a1346ace8a3b33f09b73cfb0a88212a..d23367c82fd5218f1646c33eed9bf827b993e9c3 100644
--- a/src/pairs/ir/lit.py
+++ b/src/pairs/ir/lit.py
@@ -43,5 +43,8 @@ class Lit(ASTNode):
     def __req__(self, other):
         return self.__cmp__(other)
 
+    def copy(self):
+        return Lit(self.sim, self.value)
+
     def type(self):
         return self.lit_type
diff --git a/src/pairs/ir/properties.py b/src/pairs/ir/properties.py
index c77a1859719f6bd7599fe4bbef561d0d25a4582c..0bf42dffd4c4b35bf54bf9c47cc4550226c54602 100644
--- a/src/pairs/ir/properties.py
+++ b/src/pairs/ir/properties.py
@@ -109,6 +109,9 @@ class PropertyAccess(ASTTerm, VectorExpression):
     def __str__(self):
         return f"PropertyAccess<{self.prop}, {self.index}>"
 
+    def copy(self):
+        return PropertyAccess(self.sim, self.prop, self.index)
+
     def vector_index(self, v_index):
         sizes = self.prop.sizes()
         layout = self.prop.layout()
diff --git a/src/pairs/ir/variables.py b/src/pairs/ir/variables.py
index 4b790e9b7c82fa5c8341e18c7a6e37896da97561..38a194495d7222b74e454f66647b8ca14b675b65 100644
--- a/src/pairs/ir/variables.py
+++ b/src/pairs/ir/variables.py
@@ -50,6 +50,10 @@ class Var(ASTTerm):
     def __str__(self):
         return f"Var<{self.var_name}>"
 
+    def copy(self):
+        # Terminal copies are just themselves
+        return self
+
     def set(self, other):
         return self.sim.add_statement(Assign(self.sim, self, other))
 
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 8590eba977ce394ea8a1064590d6296f4f7170f6..8094b690a984bb23e5b2ed0b4baeb54ef1a96b19 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -84,7 +84,7 @@ class AddDeviceKernels(Mutator):
                         kernel_name = f"{ast_node.name}_kernel{kernel_id}"
                         kernel = ast_node.sim.find_kernel_by_name(kernel_name)
                         if kernel is None:
-                            kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max), s.block)
+                            kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max.copy()), s.block)
                             kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator)
                             kernel_id += 1
 
@@ -144,6 +144,8 @@ class AddHostReferencesToModules(Mutator):
         return ast_node
 
     def mutate_KernelLaunch(self, ast_node):
+        ast_node._threads_per_block = self.mutate(ast_node._threads_per_block)
+        ast_node._nblocks = self.mutate(ast_node._nblocks)
         return ast_node
 
     def mutate_Property(self, ast_node):