Skip to content
Snippets Groups Projects
Commit 0a858643 authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Perform transformations in all modules and kernels

parent f1914aca
Branches
Tags
No related merge requests found
...@@ -88,7 +88,7 @@ class Kernel(ASTNode): ...@@ -88,7 +88,7 @@ class Kernel(ASTNode):
class KernelLaunch(ASTNode): class KernelLaunch(ASTNode):
def __init__(self, sim, kernel, iterator, range_min, range_max): def __init__(self, sim, kernel, iterator, range_min, range_max):
assert isinstance(module, Kernel), "KernelLaunch(): given parameter is not of type Kernel!" assert isinstance(kernel, Kernel), "KernelLaunch(): given parameter is not of type Kernel!"
super().__init__(sim) super().__init__(sim)
self._kernel = kernel self._kernel = kernel
self._iterator = iterator self._iterator = iterator
...@@ -111,3 +111,6 @@ class KernelLaunch(ASTNode): ...@@ -111,3 +111,6 @@ class KernelLaunch(ASTNode):
@property @property
def threads_per_block(self): def threads_per_block(self):
return self._threads_per_block return self._threads_per_block
def children(self):
return [self._kernel, self._iterator, self._range_min, self._range_max]
...@@ -87,3 +87,6 @@ class ModuleCall(ASTNode): ...@@ -87,3 +87,6 @@ class ModuleCall(ASTNode):
@property @property
def module(self): def module(self):
return self._module return self._module
def children(self):
return [self._module]
...@@ -66,6 +66,17 @@ class Mutator: ...@@ -66,6 +66,17 @@ class Mutator:
ast_node.block = self.mutate(ast_node.block) ast_node.block = self.mutate(ast_node.block)
return ast_node return ast_node
def mutate_Kernel(self, ast_node):
ast_node._block = self.mutate(ast_node._block)
return ast_node
def mutate_KernelLaunch(self, ast_node):
ast_node._kernel = self.mutate(ast_node._kernel)
ast_node._iterator = self.mutate(ast_node._iterator)
ast_node._range_min = self.mutate(ast_node._range_min)
ast_node._range_max = self.mutate(ast_node._range_max)
return ast_node
def mutate_ParticleFor(self, ast_node): def mutate_ParticleFor(self, ast_node):
return self.mutate_For(ast_node) return self.mutate_For(ast_node)
...@@ -84,6 +95,10 @@ class Mutator: ...@@ -84,6 +95,10 @@ class Mutator:
ast_node._block = self.mutate(ast_node._block) ast_node._block = self.mutate(ast_node._block)
return ast_node return ast_node
def mutate_ModuleCall(self, ast_node):
ast_node._module = self.mutate(ast_node._module)
return ast_node
def mutate_Realloc(self, ast_node): def mutate_Realloc(self, ast_node):
ast_node.array = self.mutate(ast_node.array) ast_node.array = self.mutate(ast_node.array)
ast_node.size = self.mutate(ast_node.size) ast_node.size = self.mutate(ast_node.size)
......
from pairs.ir.arrays import Arrays from pairs.ir.arrays import Arrays
from pairs.ir.block import Block from pairs.ir.block import Block
from pairs.ir.branches import Filter from pairs.ir.branches import Filter
from pairs.ir.kernel import Kernel
from pairs.ir.layouts import Layouts from pairs.ir.layouts import Layouts
from pairs.ir.module import Module from pairs.ir.module import Module
from pairs.ir.properties import Properties from pairs.ir.properties import Properties
......
...@@ -4,6 +4,9 @@ from pairs.ir.bin_op import BinOp ...@@ -4,6 +4,9 @@ from pairs.ir.bin_op import BinOp
from pairs.ir.block import Block from pairs.ir.block import Block
from pairs.ir.branches import Filter from pairs.ir.branches import Filter
from pairs.ir.device import CopyToDevice, CopyToHost from pairs.ir.device import CopyToDevice, CopyToHost
from pairs.ir.kernel import Kernel, KernelLaunch
from pairs.ir.lit import Lit
from pairs.ir.loops import For
from pairs.ir.module import ModuleCall from pairs.ir.module import ModuleCall
from pairs.ir.mutator import Mutator from pairs.ir.mutator import Mutator
from pairs.ir.types import Types from pairs.ir.types import Types
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment