From 0a858643e1f55d7c7199820eb84080b6598008e5 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Thu, 31 Mar 2022 01:21:42 +0200
Subject: [PATCH] Perform transformations in all modules and kernels

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/ir/kernel.py               |  5 ++++-
 src/pairs/ir/module.py               |  3 +++
 src/pairs/ir/mutator.py              | 15 +++++++++++++++
 src/pairs/sim/simulation.py          |  1 +
 src/pairs/transformations/devices.py |  3 +++
 5 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py
index f1040ec..b1cdf0f 100644
--- a/src/pairs/ir/kernel.py
+++ b/src/pairs/ir/kernel.py
@@ -88,7 +88,7 @@ class Kernel(ASTNode):
 
 class KernelLaunch(ASTNode):
     def __init__(self, sim, kernel, iterator, range_min, range_max):
-        assert isinstance(module, Kernel), "KernelLaunch(): given parameter is not of type Kernel!"
+        assert isinstance(kernel, Kernel), "KernelLaunch(): given parameter is not of type Kernel!"
         super().__init__(sim)
         self._kernel = kernel
         self._iterator = iterator
@@ -111,3 +111,6 @@ class KernelLaunch(ASTNode):
     @property
     def threads_per_block(self):
         return self._threads_per_block
+
+    def children(self):
+        return [self._kernel, self._iterator, self._range_min, self._range_max]
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
index 7353ddd..da2f225 100644
--- a/src/pairs/ir/module.py
+++ b/src/pairs/ir/module.py
@@ -87,3 +87,6 @@ class ModuleCall(ASTNode):
     @property
     def module(self):
         return self._module
+
+    def children(self):
+        return [self._module]
diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py
index 7118b07..e50e209 100644
--- a/src/pairs/ir/mutator.py
+++ b/src/pairs/ir/mutator.py
@@ -66,6 +66,17 @@ class Mutator:
         ast_node.block = self.mutate(ast_node.block)
         return ast_node
 
+    def mutate_Kernel(self, ast_node):
+        ast_node._block = self.mutate(ast_node._block)
+        return ast_node
+
+    def mutate_KernelLaunch(self, ast_node):
+        ast_node._kernel = self.mutate(ast_node._kernel)
+        ast_node._iterator = self.mutate(ast_node._iterator)
+        ast_node._range_min = self.mutate(ast_node._range_min)
+        ast_node._range_max = self.mutate(ast_node._range_max)
+        return ast_node
+
     def mutate_ParticleFor(self, ast_node):
         return self.mutate_For(ast_node)
 
@@ -84,6 +95,10 @@ class Mutator:
         ast_node._block = self.mutate(ast_node._block)
         return ast_node
 
+    def mutate_ModuleCall(self, ast_node):
+        ast_node._module = self.mutate(ast_node._module)
+        return ast_node
+
     def mutate_Realloc(self, ast_node):
         ast_node.array = self.mutate(ast_node.array)
         ast_node.size = self.mutate(ast_node.size)
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index e8bd7b2..583bbb3 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -1,6 +1,7 @@
 from pairs.ir.arrays import Arrays
 from pairs.ir.block import Block
 from pairs.ir.branches import Filter
+from pairs.ir.kernel import Kernel
 from pairs.ir.layouts import Layouts
 from pairs.ir.module import Module
 from pairs.ir.properties import Properties
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index e493d94..fec459a 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -4,6 +4,9 @@ from pairs.ir.bin_op import BinOp
 from pairs.ir.block import Block
 from pairs.ir.branches import Filter
 from pairs.ir.device import CopyToDevice, CopyToHost
+from pairs.ir.kernel import Kernel, KernelLaunch
+from pairs.ir.lit import Lit
+from pairs.ir.loops import For
 from pairs.ir.module import ModuleCall
 from pairs.ir.mutator import Mutator
 from pairs.ir.types import Types
-- 
GitLab