From 356a8343b0854f87b45f5db8237e4c3faa921a65 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 30 Jun 2024 11:15:55 +0200
Subject: [PATCH] WIP: GPU thread range tests

---
 src/pystencils/backend/platforms/sycl.py      |  2 +-
 .../kernelcreation/platform/test_basic_gpu.py | 28 ------------
 .../platform/test_gpu_platforms.py            | 45 +++++++++++++++++++
 3 files changed, 46 insertions(+), 29 deletions(-)
 delete mode 100644 tests/nbackend/kernelcreation/platform/test_basic_gpu.py
 create mode 100644 tests/nbackend/kernelcreation/platform/test_gpu_platforms.py

diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py
index 0eff7db02..cf4461309 100644
--- a/src/pystencils/backend/platforms/sycl.py
+++ b/src/pystencils/backend/platforms/sycl.py
@@ -19,7 +19,7 @@ from ...config import GpuIndexingConfig
 class SyclPlatform(GenericGpu):
 
     def __init__(
-        self, ctx: KernelCreationContext, indexing_cfg: GpuIndexingConfig | None
+        self, ctx: KernelCreationContext, indexing_cfg: GpuIndexingConfig | None = None
     ):
         super().__init__(ctx)
         self._cfg = indexing_cfg if indexing_cfg is not None else GpuIndexingConfig()
diff --git a/tests/nbackend/kernelcreation/platform/test_basic_gpu.py b/tests/nbackend/kernelcreation/platform/test_basic_gpu.py
deleted file mode 100644
index 90b88dcd7..000000000
--- a/tests/nbackend/kernelcreation/platform/test_basic_gpu.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import pytest
-
-from pystencils.field import Field
-
-from pystencils.backend.kernelcreation import (
-    KernelCreationContext,
-    FullIterationSpace
-)
-
-from pystencils.backend.ast.structural import PsBlock, PsLoop, PsComment
-from pystencils.backend.ast.expressions import PsExpression
-from pystencils.backend.ast import dfs_preorder
-
-from pystencils.backend.platforms import CudaPlatform
-
-
-@pytest.mark.parametrize("layout", ["fzyx", "zyxf", "c", "f"])
-def test_loop_nest(layout):
-    ctx = KernelCreationContext()
-
-    body = PsBlock([PsComment("Loop body goes here")])
-    platform = CudaPlatform(ctx)
-
-    #   FZYX Order
-    archetype_field = Field.create_generic("fzyx_field", spatial_dimensions=3, layout=layout)
-    ispace = FullIterationSpace.create_with_ghost_layers(ctx, 0, archetype_field)
-
-    _ = platform.materialize_iteration_space(body, ispace)
diff --git a/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py b/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py
new file mode 100644
index 000000000..87345575a
--- /dev/null
+++ b/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py
@@ -0,0 +1,45 @@
+#%%
+import pytest
+
+from pystencils.field import Field
+
+from pystencils.backend.kernelcreation import (
+    KernelCreationContext,
+    FullIterationSpace
+)
+
+from pystencils.backend.ast.structural import PsBlock, PsLoop, PsComment
+from pystencils.backend.ast.expressions import PsExpression
+from pystencils.backend.ast import dfs_preorder
+
+from pystencils.backend.platforms import CudaPlatform, SyclPlatform
+
+
+@pytest.mark.parametrize("layout", ["fzyx", "zyxf", "c", "f"])
+@pytest.mark.parametrize("platform_class", [CudaPlatform, SyclPlatform])
+def test_thread_range(platform_class, layout):
+    ctx = KernelCreationContext()
+
+    body = PsBlock([PsComment("Kernel body goes here")])
+    platform = platform_class(ctx)
+
+    dim = 3
+    archetype_field = Field.create_generic("field", spatial_dimensions=dim, layout=layout)
+    ispace = FullIterationSpace.create_with_ghost_layers(ctx, 1, archetype_field)
+
+    _, threads_range = platform.materialize_iteration_space(body, ispace)
+
+    assert threads_range.dim == dim
+    
+    loop_order = archetype_field.layout
+
+    for i in range(dim):
+        coordinate = loop_order[i]
+        dimension = ispace.dimensions[coordinate]
+        witems = threads_range.num_work_items[i]
+        desired = (dimension.stop - dimension.start) / dimension.step
+        assert witems.structurally_equal(desired)
+
+
+#%%
+test_thread_range(CudaPlatform, "fzyx")
-- 
GitLab