diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 516479b7ef75a1d9aa014ac13f2b68f88108001f..270a8b1fbdb3c42f83f02b65646cb714a5374731 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -392,7 +392,11 @@ class CGen: threads_per_block = self.generate_expression(ast_node.threads_per_block) nblocks = self.generate_expression(ast_node.nblocks) + self.print(f"if({nblocks} > 0 && {threads_per_block} > 0) {{") + self.print.add_indent(4) self.print(f"{kernel.name}<<<{nblocks}, {threads_per_block}>>>({kernel_params});") + self.print.add_indent(-4) + self.print("}") if isinstance(ast_node, ModuleCall): module = ast_node.module