Skip to content
Snippets Groups Projects
Commit 934909fd authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Implement BinOp copy() method and use it to separate host/device versions

parent 2bfb17a2
Branches
No related tags found
No related merge requests found
......@@ -133,11 +133,11 @@ class ArrayAccess(ASTTerm):
ArrayAccess.last_acc += 1
return ArrayAccess.last_acc - 1
def __init__(self, sim, array, index):
def __init__(self, sim, array, indexes):
super().__init__(sim)
self.acc_id = ArrayAccess.new_id()
self.array = array
self.partial_indexes = [Lit.cvt(sim, index)]
self.partial_indexes = indexes if isinstance(indexes, list) else [Lit.cvt(sim, indexes)]
self.flat_index = None
self.inlined = False
self.terminals = set()
......@@ -156,6 +156,9 @@ class ArrayAccess(ASTTerm):
self.inlined = True
return self
def copy(self):
return ArrayAccess(self.sim, self.array, self.partial_indexes)
def check_and_set_flat_index(self):
if len(self.partial_indexes) == self.array.ndims():
sizes = self.array.sizes()
......
......@@ -54,6 +54,9 @@ class BinOp(VectorExpression):
b = self.rhs.id() if isinstance(self.rhs, BinOp) else self.rhs
return f"BinOp<{a} {self.op} {b}>"
def copy(self):
return BinOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem)
def match(self, bin_op):
return self.lhs == bin_op.lhs and self.rhs == bin_op.rhs and self.op == bin_op.operator()
......
......@@ -9,7 +9,6 @@ class HostRef(ASTNode):
def __init__(self, sim, elem):
super().__init__(sim)
self.elem = elem
self.sim.add_statement(self)
def type(self):
return self.elem.type()
......
......@@ -43,5 +43,8 @@ class Lit(ASTNode):
def __req__(self, other):
return self.__cmp__(other)
def copy(self):
return Lit(self.sim, self.value)
def type(self):
return self.lit_type
......@@ -109,6 +109,9 @@ class PropertyAccess(ASTTerm, VectorExpression):
def __str__(self):
return f"PropertyAccess<{self.prop}, {self.index}>"
def copy(self):
return PropertyAccess(self.sim, self.prop, self.index)
def vector_index(self, v_index):
sizes = self.prop.sizes()
layout = self.prop.layout()
......
......@@ -50,6 +50,10 @@ class Var(ASTTerm):
def __str__(self):
return f"Var<{self.var_name}>"
def copy(self):
# Terminal copies are just themselves
return self
def set(self, other):
return self.sim.add_statement(Assign(self.sim, self, other))
......
......@@ -84,7 +84,7 @@ class AddDeviceKernels(Mutator):
kernel_name = f"{ast_node.name}_kernel{kernel_id}"
kernel = ast_node.sim.find_kernel_by_name(kernel_name)
if kernel is None:
kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max), s.block)
kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max.copy()), s.block)
kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator)
kernel_id += 1
......@@ -144,6 +144,8 @@ class AddHostReferencesToModules(Mutator):
return ast_node
def mutate_KernelLaunch(self, ast_node):
ast_node._threads_per_block = self.mutate(ast_node._threads_per_block)
ast_node._nblocks = self.mutate(ast_node._nblocks)
return ast_node
def mutate_Property(self, ast_node):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment