NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Line breaks, Python indentation, continuation-line alignment and blank
context lines were inferred from the code and from the hunk line counts;
the two blank context lines closing the last hunk are inferred as well.
Confirm against the repository before applying.

Summary of the changes shown below:
  * analysis/devices.py: MarkCandidateLoops.visit_Module also collects For
    loops found directly in the if/else blocks of a top-level Branch before
    marking the ones with non-literal bounds as kernel candidates.
  * transformations/devices.py: AddDeviceKernels moves kernel creation into
    create_kernel(), keeps the module name and kernel counter on the
    instance, adds check_and_mutate_branch() for candidate loops inside
    Branch blocks, and calls self.mutate() on the module block after,
    rather than before, the kernel-launch rewrite.

diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py
index 996a99ef93b4f499b9555af266573b999f4a1f56..d4552eba7183ab9a5371cb69e6bc6e07b22ed592 100644
--- a/src/pairs/analysis/devices.py
+++ b/src/pairs/analysis/devices.py
@@ -1,4 +1,5 @@
 from pairs.ir.arrays import ArrayAccess
+from pairs.ir.branches import Branch
 from pairs.ir.lit import Lit
 from pairs.ir.loops import For
 from pairs.ir.quaternions import QuaternionOp
@@ -13,10 +14,25 @@ class MarkCandidateLoops(Visitor):
         super().__init__(ast)
 
     def visit_Module(self, ast_node):
-        for s in ast_node._block.stmts:
-            if s is not None:
-                if isinstance(s, For) and (not isinstance(s.min, Lit) or not isinstance(s.max, Lit)):
-                    s.mark_as_kernel_candidate()
+        possible_candidates = []
+        for stmt in ast_node._block.stmts:
+            if stmt is not None:
+                if isinstance(stmt, Branch):
+                    for branch_stmt in stmt.block_if.stmts:
+                        if isinstance(branch_stmt, For):
+                            possible_candidates.append(branch_stmt)
+
+                    if stmt.block_else is not None:
+                        for branch_stmt in stmt.block_else.stmts:
+                            if isinstance(branch_stmt, For):
+                                possible_candidates.append(branch_stmt)
+
+                if isinstance(stmt, For):
+                    possible_candidates.append(stmt)
+
+        for stmt in possible_candidates:
+            if not isinstance(stmt.min, Lit) or not isinstance(stmt.max, Lit):
+                stmt.mark_as_kernel_candidate()
 
         self.visit_children(ast_node)
 
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 9e53b95b439d7f4a4bc6e5239aacd3530d324dbc..e050f9ed52c82b253970ec35c17097b0f710fc93 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -1,7 +1,7 @@
 import math
 from pairs.ir.actions import Actions
 from pairs.ir.block import Block
-from pairs.ir.branches import Filter
+from pairs.ir.branches import Branch, Filter
 from pairs.ir.cast import Cast
 from pairs.ir.contexts import Contexts
 from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef
@@ -86,33 +86,72 @@ class AddDeviceCopies(Mutator):
 class AddDeviceKernels(Mutator):
     def __init__(self, ast=None):
         super().__init__(ast)
+        self._module_name = None
+        self._kernel_id = 0
 
-    def mutate_Module(self, ast_node):
-        ast_node._block = self.mutate(ast_node._block)
+    def create_kernel(self, sim, iterator, rmax, block):
+        kernel_name = f"{self._module_name}_kernel{self._kernel_id}"
+        kernel = sim.find_kernel_by_name(kernel_name)
+
+        if kernel is None:
+            kernel_body = Filter(sim, ScalarOp.inline(iterator < rmax.copy(True)), block)
+            kernel = Kernel(sim, kernel_name, kernel_body, iterator)
+            self._kernel_id += 1
 
+        return kernel
+
+    def mutate_Module(self, ast_node):
         if ast_node.run_on_device:
+            self._module_name = ast_node.name
+            self._kernel_id = 0
+
             new_stmts = []
-            kernel_id = 0
-            for s in ast_node._block.stmts:
-                if s is not None:
-                    if isinstance(s, For) and s.is_kernel_candidate():
-                        kernel_name = f"{ast_node.name}_kernel{kernel_id}"
-                        kernel = ast_node.sim.find_kernel_by_name(kernel_name)
-                        if kernel is None:
-                            kernel_body = Filter(ast_node.sim,
-                                                 ScalarOp.inline(s.iterator < s.max.copy(True)),
-                                                 s.block)
-
-                            kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator)
-                            kernel_id += 1
-
-                        new_stmts.append(KernelLaunch(ast_node.sim, kernel, s.iterator, s.min, s.max))
+            for stmt in ast_node._block.stmts:
+                if stmt is not None:
+                    if isinstance(stmt, For) and stmt.is_kernel_candidate():
+                        kernel = self.create_kernel(ast_node.sim, stmt.iterator, stmt.max, stmt.block)
+                        new_stmts.append(
+                            KernelLaunch(ast_node.sim, kernel, stmt.iterator, stmt.min, stmt.max))
 
                     else:
-                        new_stmts.append(s)
+                        if isinstance(stmt, Branch):
+                            stmt = self.check_and_mutate_branch(stmt)
+
+                        new_stmts.append(stmt)
 
             ast_node._block.stmts = new_stmts
 
+        ast_node._block = self.mutate(ast_node._block)
+        return ast_node
+
+    def check_and_mutate_branch(self, ast_node):
+        new_stmts = []
+        for stmt in ast_node.block_if.stmts:
+            if stmt is not None:
+                if isinstance(stmt, For) and stmt.is_kernel_candidate():
+                    kernel = self.create_kernel(ast_node.sim, stmt.iterator, stmt.max, stmt.block)
+                    new_stmts.append(
+                        KernelLaunch(ast_node.sim, kernel, stmt.iterator, stmt.min, stmt.max))
+
+                else:
+                    new_stmts.append(stmt)
+
+        ast_node.block_if.stmts = new_stmts
+
+        if ast_node.block_else is not None:
+            new_stmts = []
+            for stmt in ast_node.block_else.stmts:
+                if stmt is not None:
+                    if isinstance(stmt, For) and stmt.is_kernel_candidate():
+                        kernel = self.create_kernel(ast_node.sim, stmt.iterator, stmt.max, stmt.block)
+                        new_stmts.append(
+                            KernelLaunch(ast_node.sim, kernel, stmt.iterator, stmt.min, stmt.max))
+
+                    else:
+                        new_stmts.append(stmt)
+
+            ast_node.block_else.stmts = new_stmts
+
         return ast_node
 
 