diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py index 61734d1a914022471b473ef61afe4c735ead08bd..814cfcbcdaf4466018e3cf491d72b6588a5346b2 100644 --- a/src/pairs/ir/arrays.py +++ b/src/pairs/ir/arrays.py @@ -133,11 +133,11 @@ class ArrayAccess(ASTTerm): ArrayAccess.last_acc += 1 return ArrayAccess.last_acc - 1 - def __init__(self, sim, array, index): + def __init__(self, sim, array, indexes): super().__init__(sim) self.acc_id = ArrayAccess.new_id() self.array = array - self.partial_indexes = [Lit.cvt(sim, index)] + self.partial_indexes = indexes if isinstance(indexes, list) else [Lit.cvt(sim, indexes)] self.flat_index = None self.inlined = False self.terminals = set() @@ -156,6 +156,9 @@ class ArrayAccess(ASTTerm): self.inlined = True return self + def copy(self): + return ArrayAccess(self.sim, self.array, self.partial_indexes) + def check_and_set_flat_index(self): if len(self.partial_indexes) == self.array.ndims(): sizes = self.array.sizes() diff --git a/src/pairs/ir/bin_op.py b/src/pairs/ir/bin_op.py index 779e9f5669af236b03ef82d866ed6417164437c0..f098e0b6c4cf9d86d98d79578c50e30ecedc1aae 100644 --- a/src/pairs/ir/bin_op.py +++ b/src/pairs/ir/bin_op.py @@ -54,6 +54,9 @@ class BinOp(VectorExpression): b = self.rhs.id() if isinstance(self.rhs, BinOp) else self.rhs return f"BinOp<{a} {self.op} {b}>" + def copy(self): + return BinOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem) + def match(self, bin_op): return self.lhs == bin_op.lhs and self.rhs == bin_op.rhs and self.op == bin_op.operator() diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py index cefa2a2c918b51e68acb2acd30c2b954a5e1a6c2..6ba75d860ba09286a97ff9c35bfa52a87f9734bd 100644 --- a/src/pairs/ir/device.py +++ b/src/pairs/ir/device.py @@ -9,7 +9,6 @@ class HostRef(ASTNode): def __init__(self, sim, elem): super().__init__(sim) self.elem = elem - self.sim.add_statement(self) def type(self): return self.elem.type() diff --git a/src/pairs/ir/lit.py b/src/pairs/ir/lit.py index ce829b3f4a1346ace8a3b33f09b73cfb0a88212a..d23367c82fd5218f1646c33eed9bf827b993e9c3 100644 --- a/src/pairs/ir/lit.py +++ b/src/pairs/ir/lit.py @@ -43,5 +43,8 @@ class Lit(ASTNode): def __req__(self, other): return self.__cmp__(other) + def copy(self): + return Lit(self.sim, self.value) + def type(self): return self.lit_type diff --git a/src/pairs/ir/properties.py b/src/pairs/ir/properties.py index c77a1859719f6bd7599fe4bbef561d0d25a4582c..0bf42dffd4c4b35bf54bf9c47cc4550226c54602 100644 --- a/src/pairs/ir/properties.py +++ b/src/pairs/ir/properties.py @@ -109,6 +109,9 @@ class PropertyAccess(ASTTerm, VectorExpression): def __str__(self): return f"PropertyAccess<{self.prop}, {self.index}>" + def copy(self): + return PropertyAccess(self.sim, self.prop, self.index) + def vector_index(self, v_index): sizes = self.prop.sizes() layout = self.prop.layout() diff --git a/src/pairs/ir/variables.py b/src/pairs/ir/variables.py index 4b790e9b7c82fa5c8341e18c7a6e37896da97561..38a194495d7222b74e454f66647b8ca14b675b65 100644 --- a/src/pairs/ir/variables.py +++ b/src/pairs/ir/variables.py @@ -50,6 +50,10 @@ class Var(ASTTerm): def __str__(self): return f"Var<{self.var_name}>" + def copy(self): + # Terminal copies are just themselves + return self + def set(self, other): return self.sim.add_statement(Assign(self.sim, self, other)) diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index 8590eba977ce394ea8a1064590d6296f4f7170f6..8094b690a984bb23e5b2ed0b4baeb54ef1a96b19 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -84,7 +84,7 @@ class AddDeviceKernels(Mutator): kernel_name = f"{ast_node.name}_kernel{kernel_id}" kernel = ast_node.sim.find_kernel_by_name(kernel_name) if kernel is None: - kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max), s.block) + kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max.copy()), s.block) kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator) kernel_id += 1 @@ -144,6 +144,8 @@ class AddHostReferencesToModules(Mutator): return ast_node def mutate_KernelLaunch(self, ast_node): + ast_node._threads_per_block = self.mutate(ast_node._threads_per_block) + ast_node._nblocks = self.mutate(ast_node._nblocks) return ast_node def mutate_Property(self, ast_node):