Skip to content
Snippets Groups Projects
Commit 4076bda2 authored by Markus Holzer's avatar Markus Holzer
Browse files

Use set instead of list

parent 7f5c7cd0
No related branches found
No related tags found
1 merge request!416Allow index expression for indexed domain kernel
Pipeline #69213 passed
...@@ -72,11 +72,11 @@ def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig): ...@@ -72,11 +72,11 @@ def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
if len(indexed_elements) > 0: if len(indexed_elements) > 0:
common_indexed_element = get_common_indexed_element(indexed_elements) common_indexed_element = get_common_indexed_element(indexed_elements)
index = list(common_indexed_element.indices[0].atoms(TypedSymbol)) index = common_indexed_element.indices[0].atoms(TypedSymbol)
assert len(index) == 1, "index expressions must only contain one symbol representing the index" assert len(index) == 1, "index expressions must only contain one symbol representing the index"
indexing = indexing_creator(iteration_space=(slice(0, common_indexed_element.shape[0], 1), *iteration_space), indexing = indexing_creator(iteration_space=(slice(0, common_indexed_element.shape[0], 1), *iteration_space),
data_layout=common_field.layout) data_layout=common_field.layout)
extended_ctrs = [index[0], *loop_counter_symbols] extended_ctrs = [index.pop(), *loop_counter_symbols]
loop_counter_assignments = indexing.get_loop_ctr_assignments(extended_ctrs) loop_counter_assignments = indexing.get_loop_ctr_assignments(extended_ctrs)
else: else:
indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout) indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout)
......
...@@ -276,10 +276,10 @@ def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block: ...@@ -276,10 +276,10 @@ def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block:
if len(indexed_elements) == 0: if len(indexed_elements) == 0:
return loop_node return loop_node
reference_element = get_common_indexed_element(indexed_elements) reference_element = get_common_indexed_element(indexed_elements)
index = list(reference_element.indices[0].atoms(TypedSymbol)) index = reference_element.indices[0].atoms(TypedSymbol)
assert len(index) == 1, "index expressions must only contain one symbol representing the index" assert len(index) == 1, "index expressions must only contain one symbol representing the index"
new_loop = ast.LoopOverCoordinate(loop_node, 0, 0, new_loop = ast.LoopOverCoordinate(loop_node, 0, 0,
reference_element.shape[0], 1, custom_loop_ctr=index[0]) reference_element.shape[0], 1, custom_loop_ctr=index.pop())
return ast.Block([new_loop]) return ast.Block([new_loop])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment