From 6ed38634be104f308e5ea4bcc776a33d78e448ae Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Sat, 28 Sep 2024 12:23:39 +0200 Subject: [PATCH] Allow index expression for indexed domain kernel --- src/pystencils/gpu/kernelcreation.py | 4 +++- src/pystencils/transformations.py | 4 +++- tests/test_indexed_kernels.py | 8 ++++---- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/pystencils/gpu/kernelcreation.py b/src/pystencils/gpu/kernelcreation.py index c2e6143bc..2feb8883a 100644 --- a/src/pystencils/gpu/kernelcreation.py +++ b/src/pystencils/gpu/kernelcreation.py @@ -72,9 +72,11 @@ def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig): if len(indexed_elements) > 0: common_indexed_element = get_common_indexed_element(indexed_elements) + index = common_indexed_element.indices[0].atoms(TypedSymbol) + assert len(index) == 1, "index expressions must only contain one symbol representing the index" indexing = indexing_creator(iteration_space=(slice(0, common_indexed_element.shape[0], 1), *iteration_space), data_layout=common_field.layout) - extended_ctrs = [common_indexed_element.indices[0], *loop_counter_symbols] + extended_ctrs = [index.pop(), *loop_counter_symbols] loop_counter_assignments = indexing.get_loop_ctr_assignments(extended_ctrs) else: indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout) diff --git a/src/pystencils/transformations.py b/src/pystencils/transformations.py index 79c24d146..5374806a9 100644 --- a/src/pystencils/transformations.py +++ b/src/pystencils/transformations.py @@ -276,8 +276,10 @@ def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block: if len(indexed_elements) == 0: return loop_node reference_element = get_common_indexed_element(indexed_elements) + index = reference_element.indices[0].atoms(TypedSymbol) + assert len(index) == 1, "index expressions must only contain one symbol representing the index" new_loop = ast.LoopOverCoordinate(loop_node, 0, 0, - reference_element.shape[0], 1, custom_loop_ctr=reference_element.indices[0]) + reference_element.shape[0], 1, custom_loop_ctr=index.pop()) return ast.Block([new_loop]) diff --git a/tests/test_indexed_kernels.py b/tests/test_indexed_kernels.py index 2c0738dcf..343b04a26 100644 --- a/tests/test_indexed_kernels.py +++ b/tests/test_indexed_kernels.py @@ -64,15 +64,15 @@ def test_indexed_domain_kernel(index_size, array_size, target, dtype): src = sp.IndexedBase(TypedSymbol(f"_data_{f.name}", dtype=const_pointer_type), shape=index_src) dst = sp.IndexedBase(TypedSymbol(f"_data_{g.name}", dtype=pointer_type), shape=index_dst) - update_rule = [ps.Assignment(FieldPointerSymbol("f", dtype, const=True), src[index]), - ps.Assignment(FieldPointerSymbol("g", dtype, const=False), dst[index]), + update_rule = [ps.Assignment(FieldPointerSymbol("f", dtype, const=True), src[index + 1]), + ps.Assignment(FieldPointerSymbol("g", dtype, const=False), dst[index + 1]), ps.Assignment(g.center, f.center)] ast = ps.create_kernel(update_rule, target=target) code = ps.get_code_str(ast) - assert f"const {dtype.c_name} * RESTRICT _data_f = (({dtype.c_name} * RESTRICT const)(_data_f[index]));" in code - assert f"{dtype.c_name} * RESTRICT _data_g = (({dtype.c_name} * RESTRICT )(_data_g[index]));" in code + assert f"const {dtype.c_name} * RESTRICT _data_f = (({dtype.c_name} * RESTRICT const)(_data_f[index + 1]));" in code + assert f"{dtype.c_name} * RESTRICT _data_g = (({dtype.c_name} * RESTRICT )(_data_g[index + 1]));" in code if target == Target.CPU: assert code.count("for") == f.spatial_dimensions + 1 -- GitLab