diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/c_printer.py index 4ca472a2996851a4af7cb1bf0e8c58866954e721..6872ad9dfcf2fb9981bdaa4f3aca901c4007c63b 100644 --- a/pystencils/nbackend/c_printer.py +++ b/pystencils/nbackend/c_printer.py @@ -29,7 +29,7 @@ class CPrinter: def function(self, func: PsKernelFunction) -> str: params = func.get_parameters() params_str = ", ".join(f"{p.dtype} {p.name}" for p in params) - decl = f"FUNC_PREFIX void {func.name} ( {params_str} )" + decl = f"FUNC_PREFIX void {func.name} ({params_str})" body = self.visit(func.body) return f"{decl}\n{body}" diff --git a/pystencils_tests/nbackend/test_basic_printing.py b/pystencils_tests/nbackend/test_basic_printing.py index 2394b9287099f1e31a56cd97b7370e9ea4185157..8679211146633fdd48bd8aadf0128cfe1fb9d012 100644 --- a/pystencils_tests/nbackend/test_basic_printing.py +++ b/pystencils_tests/nbackend/test_basic_printing.py @@ -34,5 +34,10 @@ def test_basic_kernel(): printer = CPrinter() code = printer.print(func) - assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr - 1]") >= 0 + paramlist = func.get_parameters() + params_str = ", ".join(f"{p.dtype} {p.name}" for p in paramlist) + + assert code.find("(" + params_str + ")") >= 0 + + assert code.find("u_data[ctr] = u_data[ctr - 1] + u_data[ctr + 1];") >= 0