Skip to content
Snippets Groups Projects
Commit a5e5f2c3 authored by Martin Bauer's avatar Martin Bauer
Browse files

walberla boundary class now also works for GPU boundaries

parent c2a52c28
Branches
Tags
No related merge requests found
...@@ -111,7 +111,7 @@ def field_extraction_code(field_accesses, field_name, is_temporary, declaration_ ...@@ -111,7 +111,7 @@ def field_extraction_code(field_accesses, field_name, is_temporary, declaration_
return "%s * %s;" % (field_type, field_name) return "%s * %s;" % (field_type, field_name)
else: else:
prefix = "" if no_declaration else "auto " prefix = "" if no_declaration else "auto "
return "%s%s = block->uncheckedFastGetData< %s >(%sID);" % (prefix, field_name, field_type, field_name) return "%s%s = block->getData< %s >(%sID);" % (prefix, field_name, field_type, field_name)
else: else:
assert field_name.endswith('_tmp') assert field_name.endswith('_tmp')
original_field_name = field_name[:-len('_tmp')] original_field_name = field_name[:-len('_tmp')]
...@@ -170,7 +170,7 @@ def generate_refs_for_kernel_parameters(kernel_info, prefix, parameters_to_ignor ...@@ -170,7 +170,7 @@ def generate_refs_for_kernel_parameters(kernel_info, prefix, parameters_to_ignor
@jinja2.contextfilter @jinja2.contextfilter
def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=None, stream='0'): def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=None, stream='0', spatial_shape_symbols=[]):
"""Generates the function call to a pystencils kernel """Generates the function call to a pystencils kernel
Args: Args:
...@@ -290,11 +290,14 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non ...@@ -290,11 +290,14 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non
kernel_call_lines += create_field_shape_code(fields[param.field_name], param.name) kernel_call_lines += create_field_shape_code(fields[param.field_name], param.name)
if not is_cpu: if not is_cpu:
spatial_shape_symbols = [] if not spatial_shape_symbols:
for param in ast_params: spatial_shape_symbols = []
if param.is_field_shape_argument: for param in ast_params:
spatial_shape_symbols = [TypedSymbol("%s_cpu[%d]" % (param.name, i), get_base_type(Field.SHAPE_DTYPE)) if param.is_field_shape_argument:
for i in range(field.spatial_dimensions)] spatial_shape_symbols = [TypedSymbol("%s_cpu[%d]" % (param.name, i), get_base_type(Field.SHAPE_DTYPE))
for i in range(field.spatial_dimensions)]
else:
spatial_shape_symbols = [TypedSymbol(e, get_base_type(Field.SHAPE_DTYPE)) for e in spatial_shape_symbols]
if not spatial_shape_symbols: if not spatial_shape_symbols:
for param in ast_params: for param in ast_params:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment