Skip to content
Snippets Groups Projects
Commit f2045b96 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix dimension order in threads range

parent 1080f7db
No related branches found
No related tags found
1 merge request!384Fundamental GPU Support
Pipeline #67489 passed
...@@ -13,6 +13,7 @@ from .platform import Platform ...@@ -13,6 +13,7 @@ from .platform import Platform
class GpuThreadsRange: class GpuThreadsRange:
"""Number of threads required by a GPU kernel, in order (x, y, z)."""
@staticmethod @staticmethod
def from_ispace(ispace: IterationSpace) -> GpuThreadsRange: def from_ispace(ispace: IterationSpace) -> GpuThreadsRange:
...@@ -41,6 +42,7 @@ class GpuThreadsRange: ...@@ -41,6 +42,7 @@ class GpuThreadsRange:
@property @property
def num_work_items(self) -> tuple[PsExpression, ...]: def num_work_items(self) -> tuple[PsExpression, ...]:
"""Number of work items in (x, y, z)-order."""
return self._num_work_items return self._num_work_items
@property @property
...@@ -49,7 +51,7 @@ class GpuThreadsRange: ...@@ -49,7 +51,7 @@ class GpuThreadsRange:
@staticmethod @staticmethod
def _from_full_ispace(ispace: FullIterationSpace) -> GpuThreadsRange: def _from_full_ispace(ispace: FullIterationSpace) -> GpuThreadsRange:
dimensions = ispace.dimensions_in_loop_order() dimensions = ispace.dimensions_in_loop_order()[::-1]
if len(dimensions) > 3: if len(dimensions) > 3:
raise NotImplementedError( raise NotImplementedError(
f"Cannot create a GPU threads range for an {len(dimensions)}-dimensional iteration space" f"Cannot create a GPU threads range for an {len(dimensions)}-dimensional iteration space"
......
...@@ -28,10 +28,15 @@ def test_thread_range(platform_class, layout): ...@@ -28,10 +28,15 @@ def test_thread_range(platform_class, layout):
assert threads_range.dim == dim assert threads_range.dim == dim
loop_order = archetype_field.layout match layout:
case "fzyx" | "zyxf" | "f":
indexing_order = [0, 1, 2]
case "c":
indexing_order = [2, 1, 0]
for i in range(dim): for i in range(dim):
coordinate = loop_order[i] # Slowest to fastest coordinate
coordinate = indexing_order[i]
dimension = ispace.dimensions[coordinate] dimension = ispace.dimensions[coordinate]
witems = threads_range.num_work_items[i] witems = threads_range.num_work_items[i]
desired = dimension.stop - dimension.start desired = dimension.stop - dimension.start
......
...@@ -30,7 +30,7 @@ def test_filter_kernel(target): ...@@ -30,7 +30,7 @@ def test_filter_kernel(target):
ast = create_kernel(asms, gen_config) ast = create_kernel(asms, gen_config)
kernel = ast.compile() kernel = ast.compile()
src_arr = xp.ones((42, 42)) src_arr = xp.ones((42, 31))
dst_arr = xp.zeros_like(src_arr) dst_arr = xp.zeros_like(src_arr)
kernel(src=src_arr, dst=dst_arr, weight=2.0) kernel(src=src_arr, dst=dst_arr, weight=2.0)
...@@ -55,7 +55,7 @@ def test_filter_kernel_fixedsize(target): ...@@ -55,7 +55,7 @@ def test_filter_kernel_fixedsize(target):
[1, 1, 1] [1, 1, 1]
] ]
src_arr = xp.ones((42, 42)) src_arr = xp.ones((42, 31))
dst_arr = xp.zeros_like(src_arr) dst_arr = xp.zeros_like(src_arr)
src = Field.create_from_numpy_array("src", src_arr) src = Field.create_from_numpy_array("src", src_arr)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment