diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py
index 3746ac367a65aad321796e311bafc72062a0071e..c75b3a675f4ea79f301de6ffc0ef0570f30d5b69 100644
--- a/src/pairs/analysis/devices.py
+++ b/src/pairs/analysis/devices.py
@@ -86,3 +86,7 @@ class FetchKernelReferences(Visitor):
     def visit_Var(self, ast_node):
         for k in self.kernel_stack:
             k.add_variable(ast_node, self.writing)
+
+            # Variables only have a device version when changed within kernels
+            if self.writing:
+                ast_node.device_flag = True
diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py
index c4a2c7aa6bfb4d806fbe5bc054e583875f067e56..8bf493048f0a587c2502a1d1ff784062735640dc 100644
--- a/src/pairs/analysis/modules.py
+++ b/src/pairs/analysis/modules.py
@@ -64,5 +64,3 @@ class FetchModulesReferences(Visitor):
         for m in self.module_stack:
             if not ast_node.temporary():
                 m.add_variable(ast_node, self.writing)
-                if m.run_on_device:
-                    ast_node.device_flag = True
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 984cdc91f832c8cf1ef3c9fd50ae319323eb201b..0eb4834b492440256e5712dec022854a72f5fe1b 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -6,7 +6,7 @@ from pairs.ir.branches import Branch
 from pairs.ir.cast import Cast
 from pairs.ir.contexts import Contexts
 from pairs.ir.bin_op import BinOp, Decl, VectorAccess
-from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, SetArrayFlag, SetPropertyFlag, HostRef
+from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, SetArrayFlag, SetPropertyFlag, HostRef
 from pairs.ir.functions import Call
 from pairs.ir.kernel import Kernel, KernelLaunch
 from pairs.ir.layouts import Layouts
@@ -296,6 +296,14 @@ class CGen:
             else:
                 self.print(f"pairs->copyPropertyToHost({prop_id}); // {prop_name}")
 
+        if isinstance(ast_node, CopyVar):
+            var_name = ast_node.variable.name()
+
+            if ast_node.context() == Contexts.Device:
+                self.print(f"rv_{var_name}->copyToDevice();")
+            else:
+                self.print(f"rv_{var_name}->copyToHost();")
+
         if isinstance(ast_node, ClearArrayFlag):
             array_id = ast_node.array.id()
             array_name = ast_node.array.name()
@@ -494,6 +502,10 @@ class CGen:
             tkw = Types.c_keyword(ast_node.var.type())
             self.print(f"{tkw} {ast_node.var.name()} = {ast_node.var.init_value()};")
 
+            if self.target.is_gpu() and ast_node.var.device_flag:
+                self.print(f"RuntimeVar *rv_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));")
+                #self.print(f"{tkw} *d_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));")
+
         if isinstance(ast_node, While):
             cond = self.generate_expression(ast_node.cond)
             self.print(f"while({cond}) {{")
diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py
index 8a77d3aa488a1e13b4f45963c57423f6cc694724..cefa2a2c918b51e68acb2acd30c2b954a5e1a6c2 100644
--- a/src/pairs/ir/device.py
+++ b/src/pairs/ir/device.py
@@ -46,6 +46,20 @@ class CopyProperty(ASTNode):
         return [self.prop]
 
 
+class CopyVar(ASTNode):
+    def __init__(self, sim, variable, ctx):
+        super().__init__(sim)
+        self.variable = variable
+        self.ctx = ctx
+        self.sim.add_statement(self)
+
+    def context(self):
+        return self.ctx
+
+    def children(self):
+        return [self.variable]
+
+
 class ClearArrayFlag(ASTNode):
     def __init__(self, sim, array, ctx):
         super().__init__(sim)
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
index c960ff1657451b1bcae7a919a90ffdf5410439a4..548326132ecdc67ec7b83093c01d8c9b9d72c946 100644
--- a/src/pairs/ir/module.py
+++ b/src/pairs/ir/module.py
@@ -46,6 +46,9 @@ class Module(ASTNode):
     def variables(self):
         return self._variables
 
+    def variables_to_synchronize(self):
+        return {v for v in self._variables if 'w' in self._variables[v] and v.device_flag}
+
     def read_only_variables(self):
         return [v for v in self._variables if 'w' not in self._variables[v]]
 
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 689a72c5ae1236866103f43dd80e64f85afe79b5..8590eba977ce394ea8a1064590d6296f4f7170f6 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -5,7 +5,7 @@ from pairs.ir.block import Block
 from pairs.ir.branches import Filter
 from pairs.ir.cast import Cast
 from pairs.ir.contexts import Contexts
-from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, SetArrayFlag, SetPropertyFlag, HostRef
+from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, SetArrayFlag, SetPropertyFlag, HostRef
 from pairs.ir.kernel import Kernel, KernelLaunch
 from pairs.ir.lit import Lit
 from pairs.ir.loops import For
@@ -50,10 +50,19 @@ class AddDeviceCopies(Mutator):
                     if self.module_resizes[s.module] and s.module.run_on_device:
                         new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Device)]
 
+                    if s.module.run_on_device:
+                        for v in s.module.variables_to_synchronize():
+                            new_stmts += [CopyVar(s.sim, v, Contexts.Device)]
+
                 new_stmts.append(s)
 
-                if isinstance(s, ModuleCall) and self.module_resizes[s.module] and s.module.run_on_device:
-                    new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host)]
+                if isinstance(s, ModuleCall):
+                    if s.module.run_on_device:
+                        for v in s.module.variables_to_synchronize():
+                            new_stmts += [CopyVar(s.sim, v, Contexts.Host)]
+
+                        if self.module_resizes[s.module]:
+                            new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host)]
 
         ast_node.stmts = new_stmts
         return ast_node