diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py index 3746ac367a65aad321796e311bafc72062a0071e..c75b3a675f4ea79f301de6ffc0ef0570f30d5b69 100644 --- a/src/pairs/analysis/devices.py +++ b/src/pairs/analysis/devices.py @@ -86,3 +86,7 @@ class FetchKernelReferences(Visitor): def visit_Var(self, ast_node): for k in self.kernel_stack: k.add_variable(ast_node, self.writing) + + # Variables only have a device version when changed within kernels + if self.writing: + ast_node.device_flag = True diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py index c4a2c7aa6bfb4d806fbe5bc054e583875f067e56..8bf493048f0a587c2502a1d1ff784062735640dc 100644 --- a/src/pairs/analysis/modules.py +++ b/src/pairs/analysis/modules.py @@ -64,5 +64,3 @@ class FetchModulesReferences(Visitor): for m in self.module_stack: if not ast_node.temporary(): m.add_variable(ast_node, self.writing) - if m.run_on_device: - ast_node.device_flag = True diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 984cdc91f832c8cf1ef3c9fd50ae319323eb201b..0eb4834b492440256e5712dec022854a72f5fe1b 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -6,7 +6,7 @@ from pairs.ir.branches import Branch from pairs.ir.cast import Cast from pairs.ir.contexts import Contexts from pairs.ir.bin_op import BinOp, Decl, VectorAccess -from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, SetArrayFlag, SetPropertyFlag, HostRef +from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, SetArrayFlag, SetPropertyFlag, HostRef from pairs.ir.functions import Call from pairs.ir.kernel import Kernel, KernelLaunch from pairs.ir.layouts import Layouts @@ -296,6 +296,14 @@ class CGen: else: self.print(f"pairs->copyPropertyToHost({prop_id}); // {prop_name}") + if isinstance(ast_node, CopyVar): + var_name = ast_node.variable.name() + + if ast_node.context() == Contexts.Device: + self.print(f"rv_{var_name}->copyToDevice();") + else: + self.print(f"rv_{var_name}->copyToHost();") + if isinstance(ast_node, ClearArrayFlag): array_id = ast_node.array.id() array_name = ast_node.array.name() @@ -494,6 +502,10 @@ class CGen: tkw = Types.c_keyword(ast_node.var.type()) self.print(f"{tkw} {ast_node.var.name()} = {ast_node.var.init_value()};") + if self.target.is_gpu() and ast_node.var.device_flag: + self.print(f"RuntimeVar *rv_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));") + #self.print(f"{tkw} *d_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));") + if isinstance(ast_node, While): cond = self.generate_expression(ast_node.cond) self.print(f"while({cond}) {{") diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py index 8a77d3aa488a1e13b4f45963c57423f6cc694724..cefa2a2c918b51e68acb2acd30c2b954a5e1a6c2 100644 --- a/src/pairs/ir/device.py +++ b/src/pairs/ir/device.py @@ -46,6 +46,20 @@ class CopyProperty(ASTNode): return [self.prop] +class CopyVar(ASTNode): + def __init__(self, sim, variable, ctx): + super().__init__(sim) + self.variable = variable + self.ctx = ctx + self.sim.add_statement(self) + + def context(self): + return self.ctx + + def children(self): + return [self.variable] + + class ClearArrayFlag(ASTNode): def __init__(self, sim, array, ctx): super().__init__(sim) diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index c960ff1657451b1bcae7a919a90ffdf5410439a4..548326132ecdc67ec7b83093c01d8c9b9d72c946 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -46,6 +46,9 @@ class Module(ASTNode): def variables(self): return self._variables + def variables_to_synchronize(self): + return {v for v in self._variables if 'w' in self._variables[v] and v.device_flag} + def read_only_variables(self): return [v for v in self._variables if 'w' not in self._variables[v]] diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index 689a72c5ae1236866103f43dd80e64f85afe79b5..8590eba977ce394ea8a1064590d6296f4f7170f6 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -5,7 +5,7 @@ from pairs.ir.block import Block from pairs.ir.branches import Filter from pairs.ir.cast import Cast from pairs.ir.contexts import Contexts -from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, SetArrayFlag, SetPropertyFlag, HostRef +from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, SetArrayFlag, SetPropertyFlag, HostRef from pairs.ir.kernel import Kernel, KernelLaunch from pairs.ir.lit import Lit from pairs.ir.loops import For @@ -50,10 +50,19 @@ class AddDeviceCopies(Mutator): if self.module_resizes[s.module] and s.module.run_on_device: new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Device)] + if s.module.run_on_device: + for v in s.module.variables_to_synchronize(): + new_stmts += [CopyVar(s.sim, v, Contexts.Device)] + new_stmts.append(s) - if isinstance(s, ModuleCall) and self.module_resizes[s.module] and s.module.run_on_device: - new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host)] + if isinstance(s, ModuleCall): + if s.module.run_on_device: + for v in s.module.variables_to_synchronize(): + new_stmts += [CopyVar(s.sim, v, Contexts.Host)] + + if self.module_resizes[s.module]: + new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host)] ast_node.stmts = new_stmts return ast_node