diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 2ed480c5566493b5dbe109093bb1ba79ae4ffe2f..495b5b976c39fb5b7c8f9efb860dfe1464fb0589 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -89,6 +89,11 @@ class Simulation: def kernels(self): return self.kernel_list + def find_kernel_by_name(self, name): + matches = [k for k in self.kernel_list if k.name == name] + assert len(matches) < 2, "find_kernel_by_name(): More than one match for kernel name!" + return matches[0] if len(matches) == 1 else None + def ndims(self): return self.dims diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index a5b18f647033317709e36044c71226706c3f8964..b77a0d45b818d33dfd56f6cb192df5d09cfa34c8 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -72,12 +72,14 @@ class AddDeviceKernels(Mutator): for s in ast_node._block.stmts: if s is not None: if isinstance(s, For) and (not isinstance(s.min, Lit) or not isinstance(s.max, Lit)): - if s.kernel is None: - s.kernel = Kernel(ast_node.sim, f"{ast_node.name}_kernel{kernel_id}", - Filter(ast_node.sim, BinOp.inline(s.iterator < s.max), s.block), s.iterator) + kernel_name = f"{ast_node.name}_kernel{kernel_id}" + kernel = ast_node.sim.find_kernel_by_name(kernel_name) + if kernel is None: + kernel_body = Filter(ast_node.sim, BinOp.inline(s.iterator < s.max), s.block) + kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator) kernel_id += 1 - new_stmts.append(KernelLaunch(ast_node.sim, s.kernel, s.iterator, s.min, s.max)) + new_stmts.append(KernelLaunch(ast_node.sim, kernel, s.iterator, s.min, s.max)) else: new_stmts.append(s)