diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index ec4018243e141250350dc7a5b01585209c9f0f98..285f147d150124ac53e8f27f79ced34d8de90b01 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -142,39 +142,39 @@ class CGen: self.print("}") def generate_kernel(self, kernel): - kernel_params = "" + kernel_params = "int range_start" for var in kernel.read_only_variables(): type_kw = Types.c_keyword(var.type()) decl = f"{type_kw} {var.name()}" - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {decl}" for var in kernel.write_variables(): type_kw = Types.c_keyword(var.type()) decl = f"{type_kw} *{var.name()}" - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {decl}" for array in kernel.arrays(): type_kw = Types.c_keyword(array.type()) decl = f"{type_kw} *{array.name()}" - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {decl}" for prop in kernel.properties(): type_kw = Types.c_keyword(prop.type()) decl = f"{type_kw} *{prop.name()}" - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {decl}" for array_access in kernel.array_accesses(): type_kw = Types.c_keyword(array_access.type()) decl = f"{type_kw} a{array_access.id()}" - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {decl}" for bin_op in kernel.bin_ops(): type_kw = Types.c_keyword(bin_op.type()) decl = f"{type_kw} e{bin_op.id()}" - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {decl}" self.print(f"__global__ void {kernel.name}({kernel_params}) {{") - self.print(f" const int {kernel.iterator.name()} = blockIdx.x * blockDim.x + threadIdx.x;") + self.print(f" const int {kernel.iterator.name()} = blockIdx.x * blockDim.x + threadIdx.x + range_start;") self.print.add_indent(4) self.generate_statement(kernel.block) self.print.add_indent(-4) @@ -365,31 +365,27 @@ class CGen: self.print(f"d_{array_name} = ({tkw} *) pairs::device_alloc({size});") if isinstance(ast_node, KernelLaunch): + range_start = self.generate_expression(BinOp.inline(ast_node.min)) kernel = ast_node.kernel - kernel_params = "" + kernel_params = f"{range_start}" + for var in kernel.read_only_variables(): - decl = var.name() - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {var.name()}" for var in kernel.write_variables(): - decl = var.name() - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {var.name()}" for array in kernel.arrays(): - decl = array.name() - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {array.name()}" for prop in kernel.properties(): - decl = prop.name() - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {prop.name()}" for array_access in kernel.array_accesses(): - decl = self.generate_expression(array_access) - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {self.generate_expression(array_access)}" for bin_op in kernel.bin_ops(): - decl = self.generate_expression(bin_op) - kernel_params += decl if len(kernel_params) <= 0 else f", {decl}" + kernel_params += f", {self.generate_expression(bin_op)}" threads_per_block = self.generate_expression(ast_node.threads_per_block) nblocks = self.generate_expression(ast_node.nblocks)