diff --git a/src/pystencils/codegen/gpu_indexing.py b/src/pystencils/codegen/gpu_indexing.py index c93f0f95908c0438a233abdfbd585d164e4e7f96..27d6fc817d5a9193c3faa4b170d907987fe6022e 100644 --- a/src/pystencils/codegen/gpu_indexing.py +++ b/src/pystencils/codegen/gpu_indexing.py @@ -267,6 +267,11 @@ class GpuIndexing: f" for a {rank}-dimensional kernel." ) + work_items_expr += tuple( + self._ast_factory.parse_index(1) + for _ in range(3 - rank) + ) + num_work_items = cast( _Dim3Lambda, tuple(self._kernel_factory.create_lambda(wit) for wit in work_items_expr), diff --git a/tests/kernelcreation/test_gpu.py b/tests/kernelcreation/test_gpu.py index 10b37e610cebd23c9fc961f14118aee5f24582c4..f1905b1fcb7c7406f43cfb94af2928b6f35bc3f8 100644 --- a/tests/kernelcreation/test_gpu.py +++ b/tests/kernelcreation/test_gpu.py @@ -31,7 +31,7 @@ except ImportError: @pytest.mark.parametrize("indexing_scheme", ["linear3d", "blockwise4d"]) @pytest.mark.parametrize("omit_range_check", [False, True]) @pytest.mark.parametrize("manual_grid", [False, True]) -def test_indexing_options( +def test_indexing_options_3d( indexing_scheme: str, omit_range_check: bool, manual_grid: bool ): src, dst = fields("src, dst: [3D]") @@ -76,6 +76,52 @@ def test_indexing_options( cp.testing.assert_allclose(dst_arr, expected) +@pytest.mark.parametrize("indexing_scheme", ["linear3d", "blockwise4d"]) +@pytest.mark.parametrize("omit_range_check", [False, True]) +@pytest.mark.parametrize("manual_grid", [False, True]) +def test_indexing_options_2d( + indexing_scheme: str, omit_range_check: bool, manual_grid: bool +): + src, dst = fields("src, dst: [2D]") + asm = Assignment( + dst.center(), + src[-1, 0] + + src[1, 0] + + src[0, -1] + + src[0, 1] + ) + + cfg = CreateKernelConfig(target=Target.CUDA) + cfg.gpu.indexing_scheme = indexing_scheme + cfg.gpu.omit_range_check = omit_range_check + cfg.gpu.manual_launch_grid = manual_grid + + ast = create_kernel(asm, cfg) + kernel = ast.compile() + + src_arr = cp.ones((18, 42)) + dst_arr = cp.zeros_like(src_arr) + + if manual_grid: + match indexing_scheme: + case "linear3d": + kernel.launch_config.block_size = (10, 8, 1) + kernel.launch_config.grid_size = (4, 2, 1) + case "blockwise4d": + kernel.launch_config.block_size = (40, 1, 1) + kernel.launch_config.grid_size = (16, 1, 1) + + elif indexing_scheme == "linear3d": + kernel.launch_config.block_size = (10, 8, 1) + + kernel(src=src_arr, dst=dst_arr) + + expected = cp.zeros_like(src_arr) + expected[1:-1, 1:-1].fill(4.0) + + cp.testing.assert_allclose(dst_arr, expected) + + def test_invalid_indexing_schemes(): src, dst = fields("src, dst: [4D]") asm = Assignment(src.center(0), dst.center(0))