Skip to content
Snippets Groups Projects
Commit 857f1848 authored by Christoph Alt's avatar Christoph Alt
Browse files

removed the mutable default argument from the _kernel_header and

_kernel_source function
parent 24f81cf6
No related branches found
No related tags found
1 merge request!1Add CUDA support
Pipeline #54965 skipped
...@@ -16,15 +16,17 @@ def _kernel_header(kernel_ast: KernelFunction, ...@@ -16,15 +16,17 @@ def _kernel_header(kernel_ast: KernelFunction,
dialect: Backend = Backend.C, dialect: Backend = Backend.C,
*, *,
template_file: str, template_file: str,
additional_jinja_context: dict = {}) -> str: additional_jinja_context: dict = None) -> str:
function_signature = generate_c(kernel_ast, dialect=dialect, signature_only=True) function_signature = generate_c(kernel_ast, dialect=dialect, signature_only=True)
header_guard = f'_{kernel_ast.function_name.upper()}_H' header_guard = f'_{kernel_ast.function_name.upper()}_H'
jinja_context = { jinja_context = {
'header_guard': header_guard, 'header_guard': header_guard,
'function_signature': function_signature, 'function_signature': function_signature,
**additional_jinja_context
} }
if additional_jinja_context is not None:
jinja_context.update(additional_jinja_context)
header = _env.get_template(template_file).render(**jinja_context) header = _env.get_template(template_file).render(**jinja_context)
return header return header
...@@ -34,7 +36,8 @@ def _kernel_source(kernel_ast: KernelFunction, ...@@ -34,7 +36,8 @@ def _kernel_source(kernel_ast: KernelFunction,
dialect: Backend = Backend.C, dialect: Backend = Backend.C,
*, *,
template_file: str, template_file: str,
additional_jinja_context: dict = {}) -> str: additional_jinja_context: dict = None) -> str:
kernel_name = kernel_ast.function_name kernel_name = kernel_ast.function_name
function_source = generate_c(kernel_ast, dialect=dialect) function_source = generate_c(kernel_ast, dialect=dialect)
headers = {f'"{kernel_name}.h"', '<math.h>', '<stdint.h>'} headers = {f'"{kernel_name}.h"', '<math.h>', '<stdint.h>'}
...@@ -44,9 +47,11 @@ def _kernel_source(kernel_ast: KernelFunction, ...@@ -44,9 +47,11 @@ def _kernel_source(kernel_ast: KernelFunction,
'function_source': function_source, 'function_source': function_source,
'headers': sorted(headers), 'headers': sorted(headers),
'timing': True, 'timing': True,
**additional_jinja_context,
} }
if additional_jinja_context is not None:
jinja_context.update(additional_jinja_context)
source = _env.get_template(template_file).render(**jinja_context) source = _env.get_template(template_file).render(**jinja_context)
return source return source
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment