diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 1de001643ac2610f1fbb34b5015ed457018285dd..774b9405cd04b8dc6489cd6b6ae36e4aa563f157 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -13,6 +13,7 @@ from .platform import Platform class GpuThreadsRange: + """Number of threads required by a GPU kernel, in order (x, y, z).""" @staticmethod def from_ispace(ispace: IterationSpace) -> GpuThreadsRange: @@ -41,6 +42,7 @@ class GpuThreadsRange: @property def num_work_items(self) -> tuple[PsExpression, ...]: + """Number of work items in (x, y, z)-order.""" return self._num_work_items @property @@ -49,7 +51,7 @@ class GpuThreadsRange: @staticmethod def _from_full_ispace(ispace: FullIterationSpace) -> GpuThreadsRange: - dimensions = ispace.dimensions_in_loop_order() + dimensions = ispace.dimensions_in_loop_order()[::-1] if len(dimensions) > 3: raise NotImplementedError( f"Cannot create a GPU threads range for an {len(dimensions)}-dimensional iteration space" diff --git a/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py b/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py index c1bfd28adc6af5ebeda37bab5e8d030d98fd9582..da2b3a5ad3a0e224bc47a5dd0fa4f16b0ccde520 100644 --- a/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py +++ b/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py @@ -28,10 +28,15 @@ def test_thread_range(platform_class, layout): assert threads_range.dim == dim - loop_order = archetype_field.layout + match layout: + case "fzyx" | "zyxf" | "f": + indexing_order = [0, 1, 2] + case "c": + indexing_order = [2, 1, 0] for i in range(dim): - coordinate = loop_order[i] + # Slowest to fastest coordinate + coordinate = indexing_order[i] dimension = ispace.dimensions[coordinate] witems = threads_range.num_work_items[i] desired = dimension.stop - dimension.start diff --git a/tests/nbackend/kernelcreation/test_domain_kernels.py b/tests/nbackend/kernelcreation/test_domain_kernels.py index 9a1b366384b2ac1cbaae10cfd6a299b992b00efc..5850c94d79b8eb5293c5853f65bf67c91cfd452d 100644 --- a/tests/nbackend/kernelcreation/test_domain_kernels.py +++ b/tests/nbackend/kernelcreation/test_domain_kernels.py @@ -30,7 +30,7 @@ def test_filter_kernel(target): ast = create_kernel(asms, gen_config) kernel = ast.compile() - src_arr = xp.ones((42, 42)) + src_arr = xp.ones((42, 31)) dst_arr = xp.zeros_like(src_arr) kernel(src=src_arr, dst=dst_arr, weight=2.0) @@ -55,7 +55,7 @@ def test_filter_kernel_fixedsize(target): [1, 1, 1] ] - src_arr = xp.ones((42, 42)) + src_arr = xp.ones((42, 31)) dst_arr = xp.zeros_like(src_arr) src = Field.create_from_numpy_array("src", src_arr)