From ec2f081739093661330e4d43f1932607b78863c9 Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Thu, 8 Aug 2024 22:18:04 +0200 Subject: [PATCH] test --- src/pystencils/astnodes.py | 1 + src/pystencils/gpu/kernelcreation.py | 4 +++- src/pystencils/transformations.py | 5 ++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/pystencils/astnodes.py b/src/pystencils/astnodes.py index f399287e..f0755a41 100644 --- a/src/pystencils/astnodes.py +++ b/src/pystencils/astnodes.py @@ -270,6 +270,7 @@ class KernelFunction(Node): parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols] if hasattr(self, 'indexing'): parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()] + parameters = [p for p in parameters if p.symbol.name != "svcntd()"] parameters.sort(key=lambda p: p.symbol.name) return parameters diff --git a/src/pystencils/gpu/kernelcreation.py b/src/pystencils/gpu/kernelcreation.py index c2e6143b..c7670c2d 100644 --- a/src/pystencils/gpu/kernelcreation.py +++ b/src/pystencils/gpu/kernelcreation.py @@ -72,9 +72,11 @@ def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig): if len(indexed_elements) > 0: common_indexed_element = get_common_indexed_element(indexed_elements) + index = list(common_indexed_element.indices[0].atoms(TypedSymbol)) + assert len(index) == 1, "index expressions must only contain one symbol representing the index" indexing = indexing_creator(iteration_space=(slice(0, common_indexed_element.shape[0], 1), *iteration_space), data_layout=common_field.layout) - extended_ctrs = [common_indexed_element.indices[0], *loop_counter_symbols] + extended_ctrs = [index[0], *loop_counter_symbols] loop_counter_assignments = indexing.get_loop_ctr_assignments(extended_ctrs) else: indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout) diff --git a/src/pystencils/transformations.py b/src/pystencils/transformations.py index 79c24d14..2a921031 100644 --- a/src/pystencils/transformations.py +++ b/src/pystencils/transformations.py @@ -276,8 +276,11 @@ def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block: if len(indexed_elements) == 0: return loop_node reference_element = get_common_indexed_element(indexed_elements) + index = list(reference_element.indices[0].atoms(TypedSymbol)) + assert len(index) == 1, "index expressions must only contain one symbol representing the index" + new_loop = ast.LoopOverCoordinate(loop_node, 0, 0, - reference_element.shape[0], 1, custom_loop_ctr=reference_element.indices[0]) + reference_element.shape[0], 1, custom_loop_ctr=index[0]) return ast.Block([new_loop]) -- GitLab