From 0a858643e1f55d7c7199820eb84080b6598008e5 Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Thu, 31 Mar 2022 01:21:42 +0200 Subject: [PATCH] Perform transformations in all modules and kernels Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- src/pairs/ir/kernel.py | 5 ++++- src/pairs/ir/module.py | 3 +++ src/pairs/ir/mutator.py | 15 +++++++++++++++ src/pairs/sim/simulation.py | 1 + src/pairs/transformations/devices.py | 3 +++ 5 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py index f1040ec..b1cdf0f 100644 --- a/src/pairs/ir/kernel.py +++ b/src/pairs/ir/kernel.py @@ -88,7 +88,7 @@ class Kernel(ASTNode): class KernelLaunch(ASTNode): def __init__(self, sim, kernel, iterator, range_min, range_max): - assert isinstance(module, Kernel), "KernelLaunch(): given parameter is not of type Kernel!" + assert isinstance(kernel, Kernel), "KernelLaunch(): given parameter is not of type Kernel!" super().__init__(sim) self._kernel = kernel self._iterator = iterator @@ -111,3 +111,6 @@ class KernelLaunch(ASTNode): @property def threads_per_block(self): return self._threads_per_block + + def children(self): + return [self._kernel, self._iterator, self._range_min, self._range_max] diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index 7353ddd..da2f225 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -87,3 +87,6 @@ class ModuleCall(ASTNode): @property def module(self): return self._module + + def children(self): + return [self._module] diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index 7118b07..e50e209 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -66,6 +66,17 @@ class Mutator: ast_node.block = self.mutate(ast_node.block) return ast_node + def mutate_Kernel(self, ast_node): + ast_node._block = self.mutate(ast_node._block) + return ast_node + + def mutate_KernelLaunch(self, ast_node): + ast_node._kernel = self.mutate(ast_node._kernel) + ast_node._iterator = self.mutate(ast_node._iterator) + ast_node._range_min = self.mutate(ast_node._range_min) + ast_node._range_max = self.mutate(ast_node._range_max) + return ast_node + def mutate_ParticleFor(self, ast_node): return self.mutate_For(ast_node) @@ -84,6 +95,10 @@ class Mutator: ast_node._block = self.mutate(ast_node._block) return ast_node + def mutate_ModuleCall(self, ast_node): + ast_node._module = self.mutate(ast_node._module) + return ast_node + def mutate_Realloc(self, ast_node): ast_node.array = self.mutate(ast_node.array) ast_node.size = self.mutate(ast_node.size) diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index e8bd7b2..583bbb3 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -1,6 +1,7 @@ from pairs.ir.arrays import Arrays from pairs.ir.block import Block from pairs.ir.branches import Filter +from pairs.ir.kernel import Kernel from pairs.ir.layouts import Layouts from pairs.ir.module import Module from pairs.ir.properties import Properties diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index e493d94..fec459a 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -4,6 +4,9 @@ from pairs.ir.bin_op import BinOp from pairs.ir.block import Block from pairs.ir.branches import Filter from pairs.ir.device import CopyToDevice, CopyToHost +from pairs.ir.kernel import Kernel, KernelLaunch +from pairs.ir.lit import Lit +from pairs.ir.loops import For from pairs.ir.module import ModuleCall from pairs.ir.mutator import Mutator from pairs.ir.types import Types -- GitLab