From 934909fd178f037a5d894422e569526b80f24086 Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Tue, 22 Nov 2022 17:25:18 +0100 Subject: [PATCH] Implement BinOp copy() method and use it to separate host/device versions Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- src/pairs/ir/arrays.py | 7 +++++-- src/pairs/ir/bin_op.py | 3 +++ src/pairs/ir/device.py | 1 - src/pairs/ir/lit.py | 3 +++ src/pairs/ir/properties.py | 3 +++ src/pairs/ir/variables.py | 4 ++++ src/pairs/transformations/devices.py | 4 +++- 7 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py index 61734d1..814cfcb 100644 --- a/src/pairs/ir/arrays.py +++ b/src/pairs/ir/arrays.py @@ -133,11 +133,11 @@ class ArrayAccess(ASTTerm): ArrayAccess.last_acc += 1 return ArrayAccess.last_acc - 1 - def __init__(self, sim, array, index): + def __init__(self, sim, array, indexes): super().__init__(sim) self.acc_id = ArrayAccess.new_id() self.array = array - self.partial_indexes = [Lit.cvt(sim, index)] + self.partial_indexes = indexes if isinstance(indexes, list) else [Lit.cvt(sim, indexes)] self.flat_index = None self.inlined = False self.terminals = set() @@ -156,6 +156,9 @@ class ArrayAccess(ASTTerm): self.inlined = True return self + def copy(self): + return ArrayAccess(self.sim, self.array, self.partial_indexes) + def check_and_set_flat_index(self): if len(self.partial_indexes) == self.array.ndims(): sizes = self.array.sizes() diff --git a/src/pairs/ir/bin_op.py b/src/pairs/ir/bin_op.py index 779e9f5..f098e0b 100644 --- a/src/pairs/ir/bin_op.py +++ b/src/pairs/ir/bin_op.py @@ -54,6 +54,9 @@ class BinOp(VectorExpression): b = self.rhs.id() if isinstance(self.rhs, BinOp) else self.rhs return f"BinOp<{a} {self.op} {b}>" + def copy(self): + return BinOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem) + def match(self, bin_op): return self.lhs == bin_op.lhs and self.rhs == bin_op.rhs and self.op == bin_op.operator() diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py index cefa2a2..6ba75d8 100644 --- a/src/pairs/ir/device.py +++ b/src/pairs/ir/device.py @@ -9,7 +9,6 @@ class HostRef(ASTNode): def __init__(self, sim, elem): super().__init__(sim) self.elem = elem - self.sim.add_statement(self) def type(self): return self.elem.type() diff --git a/src/pairs/ir/lit.py b/src/pairs/ir/lit.py index ce829b3..d23367c 100644 --- a/src/pairs/ir/lit.py +++ b/src/pairs/ir/lit.py @@ -43,5 +43,8 @@ class Lit(ASTNode): def __req__(self, other): return self.__cmp__(other) + def copy(self): + return Lit(self.sim, self.value) + def type(self): return self.lit_type diff --git a/src/pairs/ir/properties.py b/src/pairs/ir/properties.py index c77a185..0bf42df 100644 --- a/src/pairs/ir/properties.py +++ b/src/pairs/ir/properties.py @@ -109,6 +109,9 @@ class PropertyAccess(ASTTerm, VectorExpression): def __str__(self): return f"PropertyAccess<{self.prop}, {self.index}>" + def copy(self): + return PropertyAccess(self.sim, self.prop, self.index) + def vector_index(self, v_index): sizes = self.prop.sizes() layout = self.prop.layout() diff --git a/src/pairs/ir/variables.py b/src/pairs/ir/variables.py index 4b790e9..38a1944 100644 --- a/src/pairs/ir/variables.py +++ b/src/pairs/ir/variables.py @@ -50,6 +50,10 @@ class Var(ASTTerm): def __str__(self): return f"Var<{self.var_name}>" + def copy(self): + # Terminal copies are just themselves + return self + def set(self, other): return self.sim.add_statement(Assign(self.sim, self, other)) diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index 8590eba..8094b69 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -84,7 +84,7 @@ class AddDeviceKernels(Mutator): kernel_name = f"{ast_node.name}_kernel{kernel_id}" kernel = ast_node.sim.find_kernel_by_name(kernel_name) if kernel is None: - kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max), s.block) + kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max.copy()), s.block) kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator) kernel_id += 1 @@ -144,6 +144,8 @@ class AddHostReferencesToModules(Mutator): return ast_node def mutate_KernelLaunch(self, ast_node): + ast_node._threads_per_block = self.mutate(ast_node._threads_per_block) + ast_node._nblocks = self.mutate(ast_node._nblocks) return ast_node def mutate_Property(self, ast_node): -- GitLab