diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index a7f39967dde062baed774597c71a7b32044fb4f5..f6827b830af5011be5c4e88566a2bd756c3b0b35 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -144,6 +144,7 @@ class CGen: kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" self.print(f"__global__ void {kernel.name}({kernel_params}) {{") + self.print(f" const int {kernel.iterator.name()} = blockIdx.x * blockDim.x + threadIdx.x;") self.generate_statement(kernel.block) self.print("}") @@ -270,19 +271,19 @@ class CGen: if isinstance(ast_node, KernelLaunch): kernel = ast_node.kernel kernel_params = "" - for var in module.read_only_variables(): + for var in kernel.read_only_variables(): decl = var.name() kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" - for var in module.write_variables(): + for var in kernel.write_variables(): decl = f"&{var.name()}" kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" - for array in module.arrays(): + for array in kernel.arrays(): decl = f"d_{array.name()}" kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" - for prop in module.properties(): + for prop in kernel.properties(): decl = f"d_{prop.name()}" kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" @@ -290,10 +291,9 @@ class CGen: decl = self.generate_expression(bin_op) kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" - elems = ast_node.kernel.max - ast_node.kernel.min - threads_per_block = self.generate_expression(ast_node.kernel.threads_per_block) - blocks = self.generate_expression((elems + threads_per_block - 1) // threads_per_block) - self.print(f"{kernel.name}<<<{blocks}, {threads_per_block}>>>({kernel_params});") + elems = ast_node.max - ast_node.min + blocks = self.generate_expression((elems + ast_node.threads_per_block - 1) / ast_node.threads_per_block) + self.print(f"{kernel.name}<<<{blocks}, {ast_node.threads_per_block}>>>({kernel_params});") if isinstance(ast_node, ModuleCall): module = ast_node.module diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py index b1cdf0f7eacb5f89ea2a08889d90a50caa5873b5..334f707bf757968f7006e861bc069597ff5d5aea 100644 --- a/src/pairs/ir/kernel.py +++ b/src/pairs/ir/kernel.py @@ -1,5 +1,6 @@ from pairs.ir.arrays import Array from pairs.ir.ast_node import ASTNode +from pairs.ir.bin_op import BinOp from pairs.ir.properties import Property from pairs.ir.variables import Var @@ -7,7 +8,7 @@ from pairs.ir.variables import Var class Kernel(ASTNode): last_kernel = 0 - def __init__(self, sim, name=None, block=None): + def __init__(self, sim, name=None, block=None, iterator=None): super().__init__(sim) self._id = Kernel.last_kernel self._name = name if name is not None else "kernel" + str(Kernel.last_kernel) @@ -16,6 +17,7 @@ class Kernel(ASTNode): self._properties = {} self._bin_ops = [] self._block = block + self._iterator = iterator sim.add_kernel(self) Kernel.last_kernel += 1 @@ -31,6 +33,10 @@ class Kernel(ASTNode): def block(self): return self._block + @property + def iterator(self): + return self._iterator + def variables(self): return self._variables diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py index 08cf0e01571d3c84b7cee475d7e4390da5266149..8615eb2d46cb98bce1404ba584c74d7ab3d51ac1 100644 --- a/src/pairs/transformations/__init__.py +++ b/src/pairs/transformations/__init__.py @@ -61,8 +61,8 @@ class Transformations: def add_device_kernels(self): if self._target.is_gpu(): - self._analysis.fetch_kernel_references() self._add_device_kernels.mutate() + self._analysis.fetch_kernel_references() def apply_all(self): self.lower_everything() diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index fec459a24739bcdf52eb3a560a4f8410d6c4afc8..ba4ac4021523ff96733f4a2addf21ffa1e621f75 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -80,12 +80,12 @@ class AddDeviceKernels(Mutator): for s in ast_node._block.stmts: if s is not None: if isinstance(s, For) and (not isinstance(s.min, Lit) or not isinstance(s.max, Lit)): - kernel = Kernel(ast_node.sim, f"{ast_node.name}_kernel{kernel_id}", s.block) + kernel = Kernel(ast_node.sim, f"{ast_node.name}_kernel{kernel_id}", s.block, s.iterator) new_stmts.append(KernelLaunch(ast_node.sim, kernel, s.iterator, s.min, s.max)) kernel_id += 1 else: new_stmts.append(s) - ast_node._block_stmts = new_stmts + ast_node._block.stmts = new_stmts return ast_node