diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py index f1040ecd344f9a869693682c1cac6d87cbf460d7..b1cdf0f7eacb5f89ea2a08889d90a50caa5873b5 100644 --- a/src/pairs/ir/kernel.py +++ b/src/pairs/ir/kernel.py @@ -88,7 +88,7 @@ class Kernel(ASTNode): class KernelLaunch(ASTNode): def __init__(self, sim, kernel, iterator, range_min, range_max): - assert isinstance(module, Kernel), "KernelLaunch(): given parameter is not of type Kernel!" + assert isinstance(kernel, Kernel), "KernelLaunch(): given parameter is not of type Kernel!" super().__init__(sim) self._kernel = kernel self._iterator = iterator @@ -111,3 +111,6 @@ class KernelLaunch(ASTNode): @property def threads_per_block(self): return self._threads_per_block + + def children(self): + return [self._kernel, self._iterator, self._range_min, self._range_max] diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index 7353dddcc937f86db962bdb5b282321c3a1dd540..da2f225810dca2c83b135c05259125b34b06ef3d 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -87,3 +87,6 @@ class ModuleCall(ASTNode): @property def module(self): return self._module + + def children(self): + return [self._module] diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index 7118b0785b61e805a4f48bab37cdc97d4b6a271d..e50e209c1320a47681fbb406b781e1b3a86b0ba2 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -66,6 +66,17 @@ class Mutator: ast_node.block = self.mutate(ast_node.block) return ast_node + def mutate_Kernel(self, ast_node): + ast_node._block = self.mutate(ast_node._block) + return ast_node + + def mutate_KernelLaunch(self, ast_node): + ast_node._kernel = self.mutate(ast_node._kernel) + ast_node._iterator = self.mutate(ast_node._iterator) + ast_node._range_min = self.mutate(ast_node._range_min) + ast_node._range_max = self.mutate(ast_node._range_max) + return ast_node + def mutate_ParticleFor(self, ast_node): return self.mutate_For(ast_node) @@ -84,6 +95,10 @@ class Mutator: ast_node._block = self.mutate(ast_node._block) return ast_node + def mutate_ModuleCall(self, ast_node): + ast_node._module = self.mutate(ast_node._module) + return ast_node + def mutate_Realloc(self, ast_node): ast_node.array = self.mutate(ast_node.array) ast_node.size = self.mutate(ast_node.size) diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index e8bd7b25fac2d97b211a02ea20172f075b816ada..583bbb32d5721d2bacc3a859e644a768be79df42 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -1,6 +1,7 @@ from pairs.ir.arrays import Arrays from pairs.ir.block import Block from pairs.ir.branches import Filter +from pairs.ir.kernel import Kernel from pairs.ir.layouts import Layouts from pairs.ir.module import Module from pairs.ir.properties import Properties diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index e493d941e0a00ada75c7a6a5f7432a77da2e5e3e..fec459a24739bcdf52eb3a560a4f8410d6c4afc8 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -4,6 +4,9 @@ from pairs.ir.bin_op import BinOp from pairs.ir.block import Block from pairs.ir.branches import Filter from pairs.ir.device import CopyToDevice, CopyToHost +from pairs.ir.kernel import Kernel, KernelLaunch +from pairs.ir.lit import Lit +from pairs.ir.loops import For from pairs.ir.module import ModuleCall from pairs.ir.mutator import Mutator from pairs.ir.types import Types