From 934909fd178f037a5d894422e569526b80f24086 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Tue, 22 Nov 2022 17:25:18 +0100
Subject: [PATCH] Implement BinOp copy() method and use it to separate
 host/device versions

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/ir/arrays.py               | 7 +++++--
 src/pairs/ir/bin_op.py               | 3 +++
 src/pairs/ir/device.py               | 1 -
 src/pairs/ir/lit.py                  | 3 +++
 src/pairs/ir/properties.py           | 3 +++
 src/pairs/ir/variables.py            | 4 ++++
 src/pairs/transformations/devices.py | 4 +++-
 7 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py
index 61734d1..814cfcb 100644
--- a/src/pairs/ir/arrays.py
+++ b/src/pairs/ir/arrays.py
@@ -133,11 +133,11 @@ class ArrayAccess(ASTTerm):
         ArrayAccess.last_acc += 1
         return ArrayAccess.last_acc - 1
 
-    def __init__(self, sim, array, index):
+    def __init__(self, sim, array, indexes):
         super().__init__(sim)
         self.acc_id = ArrayAccess.new_id()
         self.array = array
-        self.partial_indexes = [Lit.cvt(sim, index)]
+        self.partial_indexes = indexes if isinstance(indexes, list) else [Lit.cvt(sim, indexes)]
         self.flat_index = None
         self.inlined = False
         self.terminals = set()
@@ -156,6 +156,9 @@ class ArrayAccess(ASTTerm):
         self.inlined = True
         return self
 
+    def copy(self):
+        return ArrayAccess(self.sim, self.array, self.partial_indexes)
+
     def check_and_set_flat_index(self):
         if len(self.partial_indexes) == self.array.ndims():
             sizes = self.array.sizes()
diff --git a/src/pairs/ir/bin_op.py b/src/pairs/ir/bin_op.py
index 779e9f5..f098e0b 100644
--- a/src/pairs/ir/bin_op.py
+++ b/src/pairs/ir/bin_op.py
@@ -54,6 +54,9 @@ class BinOp(VectorExpression):
         b = self.rhs.id() if isinstance(self.rhs, BinOp) else self.rhs
         return f"BinOp<{a} {self.op} {b}>"
 
+    def copy(self):
+        return BinOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem)
+
     def match(self, bin_op):
         return self.lhs == bin_op.lhs and self.rhs == bin_op.rhs and self.op == bin_op.operator()
 
diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py
index cefa2a2..6ba75d8 100644
--- a/src/pairs/ir/device.py
+++ b/src/pairs/ir/device.py
@@ -9,7 +9,6 @@ class HostRef(ASTNode):
     def __init__(self, sim, elem):
         super().__init__(sim)
         self.elem = elem
-        self.sim.add_statement(self)
 
     def type(self):
         return self.elem.type()
diff --git a/src/pairs/ir/lit.py b/src/pairs/ir/lit.py
index ce829b3..d23367c 100644
--- a/src/pairs/ir/lit.py
+++ b/src/pairs/ir/lit.py
@@ -43,5 +43,8 @@ class Lit(ASTNode):
     def __req__(self, other):
         return self.__cmp__(other)
 
+    def copy(self):
+        return Lit(self.sim, self.value)
+
     def type(self):
         return self.lit_type
diff --git a/src/pairs/ir/properties.py b/src/pairs/ir/properties.py
index c77a185..0bf42df 100644
--- a/src/pairs/ir/properties.py
+++ b/src/pairs/ir/properties.py
@@ -109,6 +109,9 @@ class PropertyAccess(ASTTerm, VectorExpression):
     def __str__(self):
         return f"PropertyAccess<{self.prop}, {self.index}>"
 
+    def copy(self):
+        return PropertyAccess(self.sim, self.prop, self.index)
+
     def vector_index(self, v_index):
         sizes = self.prop.sizes()
         layout = self.prop.layout()
diff --git a/src/pairs/ir/variables.py b/src/pairs/ir/variables.py
index 4b790e9..38a1944 100644
--- a/src/pairs/ir/variables.py
+++ b/src/pairs/ir/variables.py
@@ -50,6 +50,10 @@ class Var(ASTTerm):
     def __str__(self):
         return f"Var<{self.var_name}>"
 
+    def copy(self):
+        # Terminal copies are just themselves
+        return self
+
     def set(self, other):
         return self.sim.add_statement(Assign(self.sim, self, other))
 
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 8590eba..8094b69 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -84,7 +84,7 @@ class AddDeviceKernels(Mutator):
                         kernel_name = f"{ast_node.name}_kernel{kernel_id}"
                         kernel = ast_node.sim.find_kernel_by_name(kernel_name)
                         if kernel is None:
-                            kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max), s.block)
+                            kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max.copy()), s.block)
                             kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator)
                             kernel_id += 1
 
@@ -144,6 +144,8 @@ class AddHostReferencesToModules(Mutator):
         return ast_node
 
     def mutate_KernelLaunch(self, ast_node):
+        ast_node._threads_per_block = self.mutate(ast_node._threads_per_block)
+        ast_node._nblocks = self.mutate(ast_node._nblocks)
         return ast_node
 
     def mutate_Property(self, ast_node):
-- 
GitLab