diff --git a/pystencils_benchmark/common.py b/pystencils_benchmark/common.py index beeeed69cf7225dc91ad45cb8ad833ca1cdf10e7..70cabd6dbe329b0a92c6a906dff6b262698508da 100644 --- a/pystencils_benchmark/common.py +++ b/pystencils_benchmark/common.py @@ -16,15 +16,17 @@ def _kernel_header(kernel_ast: KernelFunction, dialect: Backend = Backend.C, *, template_file: str, - additional_jinja_context: dict = {}) -> str: + additional_jinja_context: dict = None) -> str: + function_signature = generate_c(kernel_ast, dialect=dialect, signature_only=True) header_guard = f'_{kernel_ast.function_name.upper()}_H' jinja_context = { 'header_guard': header_guard, 'function_signature': function_signature, - **additional_jinja_context } + if additional_jinja_context is not None: + jinja_context.update(additional_jinja_context) header = _env.get_template(template_file).render(**jinja_context) return header @@ -34,7 +36,8 @@ def _kernel_source(kernel_ast: KernelFunction, dialect: Backend = Backend.C, *, template_file: str, - additional_jinja_context: dict = {}) -> str: + additional_jinja_context: dict = None) -> str: + kernel_name = kernel_ast.function_name function_source = generate_c(kernel_ast, dialect=dialect) headers = {f'"{kernel_name}.h"', '<math.h>', '<stdint.h>'} @@ -44,9 +47,11 @@ def _kernel_source(kernel_ast: KernelFunction, 'function_source': function_source, 'headers': sorted(headers), 'timing': True, - **additional_jinja_context, } + if additional_jinja_context is not None: + jinja_context.update(additional_jinja_context) + source = _env.get_template(template_file).render(**jinja_context) return source