Skip to content
Snippets Groups Projects

Add CUDA support

Merged Markus Holzer requested to merge CUDA into master
1 unresolved thread
1 file
+ 9
4
Compare changes
  • Side-by-side
  • Inline
@@ -16,15 +16,17 @@ def _kernel_header(kernel_ast: KernelFunction,
@@ -16,15 +16,17 @@ def _kernel_header(kernel_ast: KernelFunction,
dialect: Backend = Backend.C,
dialect: Backend = Backend.C,
*,
*,
template_file: str,
template_file: str,
additional_jinja_context: dict = {}) -> str:
additional_jinja_context: dict = None) -> str:
 
function_signature = generate_c(kernel_ast, dialect=dialect, signature_only=True)
function_signature = generate_c(kernel_ast, dialect=dialect, signature_only=True)
header_guard = f'_{kernel_ast.function_name.upper()}_H'
header_guard = f'_{kernel_ast.function_name.upper()}_H'
jinja_context = {
jinja_context = {
'header_guard': header_guard,
'header_guard': header_guard,
'function_signature': function_signature,
'function_signature': function_signature,
**additional_jinja_context
}
}
 
if additional_jinja_context is not None:
 
jinja_context.update(additional_jinja_context)
header = _env.get_template(template_file).render(**jinja_context)
header = _env.get_template(template_file).render(**jinja_context)
return header
return header
@@ -34,7 +36,8 @@ def _kernel_source(kernel_ast: KernelFunction,
@@ -34,7 +36,8 @@ def _kernel_source(kernel_ast: KernelFunction,
dialect: Backend = Backend.C,
dialect: Backend = Backend.C,
*,
*,
template_file: str,
template_file: str,
additional_jinja_context: dict = {}) -> str:
additional_jinja_context: dict = None) -> str:
 
kernel_name = kernel_ast.function_name
kernel_name = kernel_ast.function_name
function_source = generate_c(kernel_ast, dialect=dialect)
function_source = generate_c(kernel_ast, dialect=dialect)
headers = {f'"{kernel_name}.h"', '<math.h>', '<stdint.h>'}
headers = {f'"{kernel_name}.h"', '<math.h>', '<stdint.h>'}
@@ -44,9 +47,11 @@ def _kernel_source(kernel_ast: KernelFunction,
@@ -44,9 +47,11 @@ def _kernel_source(kernel_ast: KernelFunction,
'function_source': function_source,
'function_source': function_source,
'headers': sorted(headers),
'headers': sorted(headers),
'timing': True,
'timing': True,
**additional_jinja_context,
}
}
 
if additional_jinja_context is not None:
 
jinja_context.update(additional_jinja_context)
 
source = _env.get_template(template_file).render(**jinja_context)
source = _env.get_template(template_file).render(**jinja_context)
return source
return source
Loading