Skip to content
Snippets Groups Projects
Commit ec2f0817 authored by Markus Holzer's avatar Markus Holzer
Browse files

test

parent c124df93
Branches
No related tags found
No related merge requests found
...@@ -270,6 +270,7 @@ class KernelFunction(Node): ...@@ -270,6 +270,7 @@ class KernelFunction(Node):
parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols] parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
if hasattr(self, 'indexing'): if hasattr(self, 'indexing'):
parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()] parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
parameters = [p for p in parameters if p.symbol.name != "svcntd()"]
parameters.sort(key=lambda p: p.symbol.name) parameters.sort(key=lambda p: p.symbol.name)
return parameters return parameters
......
...@@ -72,9 +72,11 @@ def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig): ...@@ -72,9 +72,11 @@ def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
if len(indexed_elements) > 0: if len(indexed_elements) > 0:
common_indexed_element = get_common_indexed_element(indexed_elements) common_indexed_element = get_common_indexed_element(indexed_elements)
index = list(common_indexed_element.indices[0].atoms(TypedSymbol))
assert len(index) == 1, "index expressions must only contain one symbol representing the index"
indexing = indexing_creator(iteration_space=(slice(0, common_indexed_element.shape[0], 1), *iteration_space), indexing = indexing_creator(iteration_space=(slice(0, common_indexed_element.shape[0], 1), *iteration_space),
data_layout=common_field.layout) data_layout=common_field.layout)
extended_ctrs = [common_indexed_element.indices[0], *loop_counter_symbols] extended_ctrs = [index[0], *loop_counter_symbols]
loop_counter_assignments = indexing.get_loop_ctr_assignments(extended_ctrs) loop_counter_assignments = indexing.get_loop_ctr_assignments(extended_ctrs)
else: else:
indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout) indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout)
......
...@@ -276,8 +276,11 @@ def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block: ...@@ -276,8 +276,11 @@ def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block:
if len(indexed_elements) == 0: if len(indexed_elements) == 0:
return loop_node return loop_node
reference_element = get_common_indexed_element(indexed_elements) reference_element = get_common_indexed_element(indexed_elements)
index = list(reference_element.indices[0].atoms(TypedSymbol))
assert len(index) == 1, "index expressions must only contain one symbol representing the index"
new_loop = ast.LoopOverCoordinate(loop_node, 0, 0, new_loop = ast.LoopOverCoordinate(loop_node, 0, 0,
reference_element.shape[0], 1, custom_loop_ctr=reference_element.indices[0]) reference_element.shape[0], 1, custom_loop_ctr=index[0])
return ast.Block([new_loop]) return ast.Block([new_loop])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment