From f435c3c433719ea357f02a7ede9d5f10cc2ef186 Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Thu, 3 Nov 2022 01:43:37 +0100 Subject: [PATCH] Generate code for variables changed within kernels Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- src/pairs/analysis/devices.py | 4 ++++ src/pairs/analysis/modules.py | 2 -- src/pairs/code_gen/cgen.py | 14 +++++++++++++- src/pairs/ir/device.py | 14 ++++++++++++++ src/pairs/ir/module.py | 3 +++ src/pairs/transformations/devices.py | 15 ++++++++++++--- 6 files changed, 46 insertions(+), 6 deletions(-) diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py index 3746ac3..c75b3a6 100644 --- a/src/pairs/analysis/devices.py +++ b/src/pairs/analysis/devices.py @@ -86,3 +86,7 @@ class FetchKernelReferences(Visitor): def visit_Var(self, ast_node): for k in self.kernel_stack: k.add_variable(ast_node, self.writing) + + # Variables only have a device version when changed within kernels + if self.writing: + ast_node.device_flag = True diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py index c4a2c7a..8bf4930 100644 --- a/src/pairs/analysis/modules.py +++ b/src/pairs/analysis/modules.py @@ -64,5 +64,3 @@ class FetchModulesReferences(Visitor): for m in self.module_stack: if not ast_node.temporary(): m.add_variable(ast_node, self.writing) - if m.run_on_device: - ast_node.device_flag = True diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 984cdc9..0eb4834 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -6,7 +6,7 @@ from pairs.ir.branches import Branch from pairs.ir.cast import Cast from pairs.ir.contexts import Contexts from pairs.ir.bin_op import BinOp, Decl, VectorAccess -from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, SetArrayFlag, SetPropertyFlag, HostRef +from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, SetArrayFlag, SetPropertyFlag, HostRef from pairs.ir.functions import Call from pairs.ir.kernel import Kernel, KernelLaunch from pairs.ir.layouts import Layouts @@ -296,6 +296,14 @@ class CGen: else: self.print(f"pairs->copyPropertyToHost({prop_id}); // {prop_name}") + if isinstance(ast_node, CopyVar): + var_name = ast_node.variable.name() + + if ast_node.context() == Contexts.Device: + self.print(f"rv_{var_name}->copyToDevice();") + else: + self.print(f"rv_{var_name}->copyToHost();") + if isinstance(ast_node, ClearArrayFlag): array_id = ast_node.array.id() array_name = ast_node.array.name() @@ -494,6 +502,10 @@ class CGen: tkw = Types.c_keyword(ast_node.var.type()) self.print(f"{tkw} {ast_node.var.name()} = {ast_node.var.init_value()};") + if self.target.is_gpu() and ast_node.var.device_flag: + self.print(f"RuntimeVar *rv_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));") + #self.print(f"{tkw} *d_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));") + if isinstance(ast_node, While): cond = self.generate_expression(ast_node.cond) self.print(f"while({cond}) {{") diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py index 8a77d3a..cefa2a2 100644 --- a/src/pairs/ir/device.py +++ b/src/pairs/ir/device.py @@ -46,6 +46,20 @@ class CopyProperty(ASTNode): return [self.prop] +class CopyVar(ASTNode): + def __init__(self, sim, variable, ctx): + super().__init__(sim) + self.variable = variable + self.ctx = ctx + self.sim.add_statement(self) + + def context(self): + return self.ctx + + def children(self): + return [self.variable] + + class ClearArrayFlag(ASTNode): def __init__(self, sim, array, ctx): super().__init__(sim) diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index c960ff1..5483261 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -46,6 +46,9 @@ class Module(ASTNode): def variables(self): return self._variables + def variables_to_synchronize(self): + return {v for v in self._variables if 'w' in self._variables[v] and v.device_flag} + def read_only_variables(self): return [v for v in self._variables if 'w' not in self._variables[v]] diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index 689a72c..8590eba 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -5,7 +5,7 @@ from pairs.ir.block import Block from pairs.ir.branches import Filter from pairs.ir.cast import Cast from pairs.ir.contexts import Contexts -from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, SetArrayFlag, SetPropertyFlag, HostRef +from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, SetArrayFlag, SetPropertyFlag, HostRef from pairs.ir.kernel import Kernel, KernelLaunch from pairs.ir.lit import Lit from pairs.ir.loops import For @@ -50,10 +50,19 @@ class AddDeviceCopies(Mutator): if self.module_resizes[s.module] and s.module.run_on_device: new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Device)] + if s.module.run_on_device: + for v in s.module.variables_to_synchronize(): + new_stmts += [CopyVar(s.sim, v, Contexts.Device)] + new_stmts.append(s) - if isinstance(s, ModuleCall) and self.module_resizes[s.module] and s.module.run_on_device: - new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host)] + if isinstance(s, ModuleCall): + if s.module.run_on_device: + for v in s.module.variables_to_synchronize(): + new_stmts += [CopyVar(s.sim, v, Contexts.Host)] + + if self.module_resizes[s.module]: + new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host)] ast_node.stmts = new_stmts return ast_node -- GitLab