From ff396a6a074e345925692586f239176d395c867c Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 28 Jun 2024 13:18:41 +0200
Subject: [PATCH] Improve support for GPU Thread Indexing

 - Introduce GpuThreadsRange and GpuKernelFunction to export the launch space expected by the kernel
 - Refactor GPU platforms to create and export a thread range
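
 A sketch of the intended downstream use (hypothetical host-side code; the
 actual grid size computation is left to the JIT):

     kfunc = create_kernel(assignments, config)   # a GpuKernelFunction on GPU targets
     launch_space = kfunc.threads_range           # the exported GpuThreadsRange
     # evaluate launch_space.num_work_items to size the launch grid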
---
 .../backend/kernelcreation/iteration_space.py |  18 ++-
 src/pystencils/backend/kernelfunction.py      | 102 ++++++++++++++-
 src/pystencils/backend/platforms/__init__.py  |   5 +-
 src/pystencils/backend/platforms/cuda.py      |  88 +++++++++++++
 .../backend/platforms/generic_gpu.py          | 118 +++++++-----------
 src/pystencils/backend/platforms/platform.py  |   3 +-
 src/pystencils/backend/platforms/sycl.py      |  58 +++++----
 src/pystencils/kernelcreation.py              |  79 ++++--------
 .../kernelcreation/platform/test_basic_gpu.py |   4 +-
 9 files changed, 319 insertions(+), 156 deletions(-)
 create mode 100644 src/pystencils/backend/platforms/cuda.py

diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index d50e4b50f..38aca4efe 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -208,14 +208,28 @@ class FullIterationSpace(IterationSpace):
     @property
     def archetype_field(self) -> Field | None:
         return self._archetype_field
+
+    def dimensions_in_loop_order(self) -> Sequence[FullIterationSpace.Dimension]:
+        """Return the dimensions of this iteration space ordered from the fastest to the slowest coordinate.
 
-    def actual_iterations(self, dimension: int | None = None) -> PsExpression:
+        If an archetype field is specified, the field layout is used to determine the ideal loop order;
+        otherwise, the dimensions are returned in their original order.
+        """
+        if self._archetype_field is not None:
+            return [self._dimensions[i] for i in self._archetype_field.layout]
+        else:
+            return self._dimensions
+
+    def actual_iterations(
+        self, dimension: int | FullIterationSpace.Dimension | None = None
+    ) -> PsExpression:
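+        """Return the number of iterations in the given dimension, or the total
+        number of iterations of the entire space if no dimension is given."""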
         if dimension is None:
             return reduce(
                 mul, (self.actual_iterations(d) for d in range(len(self.dimensions)))
             )
         else:
-            dim = self.dimensions[dimension]
+            if isinstance(dimension, FullIterationSpace.Dimension):
+                dim = dimension
+            else:
+                dim = self.dimensions[dimension]
             one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype))
             return one + (dim.stop - dim.start - one) / dim.step
 
diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py
index 97837492b..dbc1b6519 100644
--- a/src/pystencils/backend/kernelfunction.py
+++ b/src/pystencils/backend/kernelfunction.py
@@ -1,9 +1,14 @@
 from __future__ import annotations
 
 from abc import ABC
-from typing import Callable, Sequence
+from typing import Callable, Sequence, Iterable
 
 from .ast.structural import PsBlock
+from .ast.analysis import collect_required_headers, collect_undefined_symbols
+from .arrays import PsArrayShapeSymbol, PsArrayStrideSymbol, PsArrayBasePointer
+from .symbols import PsSymbol
+from .kernelcreation.context import KernelCreationContext
+from .platforms import Platform, GpuThreadsRange
 
 from .constraints import KernelParamsConstraint
 from ..types import PsType
@@ -161,3 +166,98 @@ class KernelFunction:
 
     def compile(self) -> Callable[..., None]:
         return self._jit.compile(self)
+
+
+def create_cpu_kernel_function(
+    ctx: KernelCreationContext,
+    platform: Platform,
+    body: PsBlock,
+    function_name: str,
+    target_spec: Target,
+    jit: JitBase,
+):
+    undef_symbols = collect_undefined_symbols(body)
+
+    params = _get_function_params(ctx, undef_symbols)
+    req_headers = _get_headers(ctx, platform, body)
+
+    return KernelFunction(
+        body, target_spec, function_name, params, req_headers, ctx.constraints, jit
+    )
+
+
+class GpuKernelFunction(KernelFunction):
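+    """Internal representation of a kernel function targeted at GPUs,
+    which additionally carries the thread range spanning its iteration space."""
+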
+    def __init__(
+        self,
+        body: PsBlock,
+        threads_range: GpuThreadsRange,
+        target: Target,
+        name: str,
+        parameters: Sequence[KernelParameter],
+        required_headers: set[str],
+        constraints: Sequence[KernelParamsConstraint],
+        jit: JitBase = no_jit,
+    ):
+        super().__init__(
+            body, target, name, parameters, required_headers, constraints, jit
+        )
+        self._threads_range = threads_range
+
+    @property
+    def threads_range(self) -> GpuThreadsRange:
+        return self._threads_range
+
+
+def create_gpu_kernel_function(
+    ctx: KernelCreationContext,
+    platform: Platform,
+    body: PsBlock,
+    threads_range: GpuThreadsRange,
+    function_name: str,
+    target_spec: Target,
+    jit: JitBase,
+):
+    undef_symbols = collect_undefined_symbols(body)
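+    #   Symbols occurring in the thread range expressions must be exposed
+    #   as kernel parameters as well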
+    for threads in threads_range.num_work_items:
+        undef_symbols |= collect_undefined_symbols(threads)
+
+    params = _get_function_params(ctx, undef_symbols)
+    req_headers = _get_headers(ctx, platform, body)
+
+    return GpuKernelFunction(
+        body,
+        threads_range,
+        target_spec,
+        function_name,
+        params,
+        req_headers,
+        ctx.constraints,
+        jit,
+    )
+
+
+def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]):
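+    #   Match each undefined symbol to the kernel parameter class it belongs to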
+    params: list[KernelParameter] = []
+    for symb in symbols:
+        match symb:
+            case PsArrayShapeSymbol(name, _, arr, coord):
+                field = ctx.find_field(arr.name)
+                params.append(FieldShapeParam(name, symb.get_dtype(), field, coord))
+            case PsArrayStrideSymbol(name, _, arr, coord):
+                field = ctx.find_field(arr.name)
+                params.append(FieldStrideParam(name, symb.get_dtype(), field, coord))
+            case PsArrayBasePointer(name, _, arr):
+                field = ctx.find_field(arr.name)
+                params.append(FieldPointerParam(name, symb.get_dtype(), field))
+            case PsSymbol(name, _):
+                params.append(KernelParameter(name, symb.get_dtype()))
+
+    params.sort(key=lambda p: p.name)
+    return params
+
+
+def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock):
+    req_headers = collect_required_headers(body)
+    req_headers |= platform.required_headers
+    req_headers |= ctx.required_headers
+    return req_headers
diff --git a/src/pystencils/backend/platforms/__init__.py b/src/pystencils/backend/platforms/__init__.py
index af4d88e79..9332453c6 100644
--- a/src/pystencils/backend/platforms/__init__.py
+++ b/src/pystencils/backend/platforms/__init__.py
@@ -1,6 +1,7 @@
 from .platform import Platform
 from .generic_cpu import GenericCpu, GenericVectorCpu
-from .generic_gpu import GenericGpu
+from .generic_gpu import GenericGpu, GpuThreadsRange
+from .cuda import CudaPlatform
 from .x86 import X86VectorCpu, X86VectorArch
 from .sycl import SyclPlatform
 
@@ -11,5 +12,7 @@ __all__ = [
     "X86VectorCpu",
     "X86VectorArch",
     "GenericGpu",
+    "GpuThreadsRange",
+    "CudaPlatform",
     "SyclPlatform",
 ]
diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
new file mode 100644
index 000000000..eb1576190
--- /dev/null
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -0,0 +1,88 @@
+from pystencils.backend.functions import CFunction, PsMathFunction
+from pystencils.types.types import PsType
+from .platform import Platform
+
+from ..kernelcreation.iteration_space import (
+    IterationSpace,
+    FullIterationSpace,
+    # SparseIterationSpace,
+)
+
+from ..ast.structural import PsBlock, PsConditional
+from ..ast.expressions import (
+    PsExpression,
+    PsLiteralExpr,
+    PsAdd,
+)
+from ..ast.expressions import PsLt, PsAnd
+from ...types import PsSignedIntegerType
+from ..literals import PsLiteral
+
+int32 = PsSignedIntegerType(width=32, const=False)
+
+BLOCK_IDX = [
+    PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z")
+]
+THREAD_IDX = [
+    PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z")
+]
+BLOCK_DIM = [
+    PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z")
+]
+GRID_DIM = [
+    PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z")
+]
+
+
+class CudaPlatform(Platform):
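+    """Platform for CUDA-capable GPU devices."""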
+
+    @property
+    def required_headers(self) -> set[str]:
+        return {"gpu_defines.h"}
+
+    def materialize_iteration_space(
+        self, body: PsBlock, ispace: IterationSpace
+    ) -> PsBlock:
+        if isinstance(ispace, FullIterationSpace):
+            return self._guard_full_iteration_space(body, ispace)
+        else:
+            assert False, "unreachable code"
+
+    def cuda_indices(self, dim):
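+        #   Global thread index per dimension: blockIdx * blockDim + threadIdx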
+        block_size = BLOCK_DIM
+        indices = [
+            block_index * bs + thread_idx
+            for block_index, bs, thread_idx in zip(BLOCK_IDX, block_size, THREAD_IDX)
+        ]
+
+        return indices[:dim]
+
+    def select_function(
+        self, math_function: PsMathFunction, dtype: PsType
+    ) -> CFunction:
+        raise NotImplementedError()
+
+    #   Internals
+    def _guard_full_iteration_space(
+        self, body: PsBlock, ispace: FullIterationSpace
+    ) -> PsBlock:
+
+        dimensions = ispace.dimensions
+
+        #   Determine loop order by permuting dimensions
+        archetype_field = ispace.archetype_field
+        if archetype_field is not None:
+            loop_order = archetype_field.layout
+            dimensions = [dimensions[coordinate] for coordinate in loop_order]
+
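+        #   Offset each global thread index by the iteration space start and
+        #   guard the body against out-of-range threads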
+        start = [
+            PsAdd(c, d.start)
+            for c, d in zip(self.cuda_indices(len(dimensions)), dimensions[::-1])
+        ]
+        conditions = [PsLt(c, d.stop) for c, d in zip(start, dimensions[::-1])]
+
+        condition: PsExpression = conditions[0]
+        for c in conditions[1:]:
+            condition = PsAnd(condition, c)
+
+        return PsBlock([PsConditional(condition, body)])
diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index 6e860d32b..1de001643 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -1,88 +1,66 @@
-from pystencils.backend.functions import CFunction, PsMathFunction
-from pystencils.types.types import PsType
-from .platform import Platform
+from __future__ import annotations
+from typing import Sequence
+from abc import abstractmethod
 
+from ..ast.expressions import PsExpression
+from ..ast.structural import PsBlock
 from ..kernelcreation.iteration_space import (
     IterationSpace,
     FullIterationSpace,
-    # SparseIterationSpace,
-)
-
-from ..ast.structural import PsBlock, PsConditional
-from ..ast.expressions import (
-    PsExpression,
-    PsLiteralExpr,
-    PsAdd,
+    SparseIterationSpace,
 )
-from ..ast.expressions import PsLt, PsAnd
-from ...types import PsSignedIntegerType
-from ..literals import PsLiteral
-
-int32 = PsSignedIntegerType(width=32, const=False)
-
-BLOCK_IDX = [
-    PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z")
-]
-THREAD_IDX = [
-    PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z")
-]
-BLOCK_DIM = [
-    PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z")
-]
-GRID_DIM = [
-    PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z")
-]
+from .platform import Platform
 
 
-class GenericGpu(Platform):
+class GpuThreadsRange:
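+    """Number of threads required to execute a GPU kernel, per dimension of its iteration space."""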
 
-    @property
-    def required_headers(self) -> set[str]:
-        return {"gpu_defines.h"}
-
-    def materialize_iteration_space(
-        self, body: PsBlock, ispace: IterationSpace
-    ) -> PsBlock:
+    @staticmethod
+    def from_ispace(ispace: IterationSpace) -> GpuThreadsRange:
         if isinstance(ispace, FullIterationSpace):
-            return self._guard_full_iteration_space(body, ispace)
+            return GpuThreadsRange._from_full_ispace(ispace)
+        elif isinstance(ispace, SparseIterationSpace):
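+            #   One work item per entry of the index list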
+            work_items = (PsExpression.make(ispace.index_list.shape[0]),)
+            return GpuThreadsRange(work_items)
         else:
-            assert False, "unreachable code"
+            assert False, "unreachable code"
 
-    def cuda_indices(self, dim):
-        block_size = BLOCK_DIM
-        indices = [
-            block_index * bs + thread_idx
-            for block_index, bs, thread_idx in zip(BLOCK_IDX, block_size, THREAD_IDX)
-        ]
+    def __init__(
+        self,
+        num_work_items: Sequence[PsExpression],
+    ):
+        self._dim = len(num_work_items)
+        self._num_work_items = tuple(num_work_items)
 
-        return indices[:dim]
 
-    def select_function(
-        self, math_function: PsMathFunction, dtype: PsType
-    ) -> CFunction:
-        raise NotImplementedError()
 
-    #   Internals
-    def _guard_full_iteration_space(
-        self, body: PsBlock, ispace: FullIterationSpace
-    ) -> PsBlock:
-
-        dimensions = ispace.dimensions
+    @property
+    def num_work_items(self) -> tuple[PsExpression, ...]:
+        return self._num_work_items
 
-        #   Determine loop order by permuting dimensions
-        archetype_field = ispace.archetype_field
-        if archetype_field is not None:
-            loop_order = archetype_field.layout
-            dimensions = [dimensions[coordinate] for coordinate in loop_order]
+    @property
+    def dim(self) -> int:
+        return self._dim
 
-        start = [
-            PsAdd(c, d.start)
-            for c, d in zip(self.cuda_indices(len(dimensions)), dimensions[::-1])
-        ]
-        conditions = [PsLt(c, d.stop) for c, d in zip(start, dimensions[::-1])]
+    @staticmethod
+    def _from_full_ispace(ispace: FullIterationSpace) -> GpuThreadsRange:
+        dimensions = ispace.dimensions_in_loop_order()
+        if len(dimensions) > 3:
+            raise NotImplementedError(
+                f"Cannot create a GPU threads range for an {len(dimensions)}-dimensional iteration space"
+            )
+        work_items = [ispace.actual_iterations(dim) for dim in dimensions]
+        return GpuThreadsRange(work_items)
 
-        condition: PsExpression = conditions[0]
-        for c in conditions[1:]:
-            condition = PsAnd(condition, c)
 
-        return PsBlock([PsConditional(condition, body)])
+class GenericGpu(Platform):
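+    """Base class for GPU platforms.
+
+    Subclasses return the materialized kernel body together with the
+    GPU thread range spanning the iteration space.
+    """
+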
+    @abstractmethod
+    def materialize_iteration_space(
+        self, block: PsBlock, ispace: IterationSpace
+    ) -> tuple[PsBlock, GpuThreadsRange]:
+        pass
diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py
index 2c718ae5f..27a3e7c02 100644
--- a/src/pystencils/backend/platforms/platform.py
+++ b/src/pystencils/backend/platforms/platform.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Any
 
 from ..ast.structural import PsBlock
 from ..functions import PsMathFunction, CFunction
@@ -28,7 +29,7 @@ class Platform(ABC):
     @abstractmethod
     def materialize_iteration_space(
         self, block: PsBlock, ispace: IterationSpace
-    ) -> PsBlock:
+    ) -> PsBlock | tuple[PsBlock, Any]:
         pass
 
     @abstractmethod
diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py
index 1e4cbcaa0..9ee33cc4e 100644
--- a/src/pystencils/backend/platforms/sycl.py
+++ b/src/pystencils/backend/platforms/sycl.py
@@ -1,29 +1,26 @@
 from ..functions import CFunction, PsMathFunction, MathFunctions
-from ..ast.structural import PsBlock
 from ..kernelcreation.iteration_space import (
     IterationSpace,
     FullIterationSpace,
     SparseIterationSpace,
 )
-from ..ast.structural import PsDeclaration
-from ..ast.expressions import (
-    PsExpression,
-    PsSymbolExpr,
-    PsSubscript,
-)
+from ..ast.structural import PsDeclaration, PsBlock, PsConditional
+from ..ast.expressions import PsExpression, PsSymbolExpr, PsSubscript, PsLt, PsAnd
 from ..extensions.cpp import CppMethodCall
 
 from ..kernelcreation.context import KernelCreationContext
 from ..constants import PsConstant
-from .platform import Platform
+from .generic_gpu import GenericGpu, GpuThreadsRange
 from ..exceptions import MaterializationError
 from ...types import PsType, PsCustomType, PsIeeeFloatType, constify
 from ...config import GpuIndexingConfig
 
 
-class SyclPlatform(Platform):
+class SyclPlatform(GenericGpu):
 
-    def __init__(self, ctx: KernelCreationContext, indexing_cfg: GpuIndexingConfig | None):
+    def __init__(
+        self, ctx: KernelCreationContext, indexing_cfg: GpuIndexingConfig | None
+    ):
         super().__init__(ctx)
         self._cfg = indexing_cfg if indexing_cfg is not None else GpuIndexingConfig()
 
@@ -33,7 +30,7 @@ class SyclPlatform(Platform):
 
     def materialize_iteration_space(
         self, body: PsBlock, ispace: IterationSpace
-    ) -> PsBlock:
+    ) -> tuple[PsBlock, GpuThreadsRange]:
         if isinstance(ispace, FullIterationSpace):
             return self._prepend_dense_translation(body, ispace)
         elif isinstance(ispace, SparseIterationSpace):
@@ -65,22 +62,17 @@ class SyclPlatform(Platform):
 
     def _prepend_dense_translation(
         self, body: PsBlock, ispace: FullIterationSpace
-    ) -> PsBlock:
+    ) -> tuple[PsBlock, GpuThreadsRange]:
         rank = ispace.rank
         id_type = self._id_type(rank)
         id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type))
         id_decl = self._id_declaration(rank, id_symbol)
 
-        #   Determine loop order by permuting dimensions
-        archetype_field = ispace.archetype_field
-        
-        if archetype_field is not None:
-            loop_order = archetype_field.layout
-            dimensions = [ispace.dimensions[coordinate] for coordinate in loop_order]
-        else:
-            dimensions = ispace.dimensions
+        dimensions = ispace.dimensions_in_loop_order()
+        launch_config = GpuThreadsRange.from_ispace(ispace)
 
-        unpackings = [id_decl]
+        indexing_decls = [id_decl]
+        conds = []
         for i, dim in enumerate(dimensions[::-1]):
             coord = PsExpression.make(PsConstant(i, self._ctx.index_dtype))
             work_item_idx = PsSubscript(id_symbol, coord)
@@ -89,14 +81,26 @@ class SyclPlatform(Platform):
             work_item_idx.dtype = dim.counter.get_dtype()
 
             ctr = PsExpression.make(dim.counter)
-            unpackings.append(PsDeclaration(ctr, dim.start + work_item_idx * dim.step))
+            indexing_decls.append(
+                PsDeclaration(ctr, dim.start + work_item_idx * dim.step)
+            )
+            if not self._cfg.omit_range_check:
+                conds.append(PsLt(ctr, dim.stop))
+
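+        #   Combine the range checks into a single guard condition around the body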
+        if conds:
+            condition: PsExpression = conds[0]
+            for cond in conds[1:]:
+                condition = PsAnd(condition, cond)
+            ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
+        else:
+            body.statements = indexing_decls + body.statements
+            ast = body
 
-        body.statements = unpackings + body.statements
-        return body
+        return ast, launch_config
 
     def _prepend_sparse_translation(
         self, body: PsBlock, ispace: SparseIterationSpace
-    ) -> PsBlock:
+    ) -> tuple[PsBlock, GpuThreadsRange]:
         id_type = PsCustomType("sycl::id< 1 >", const=True)
         id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type))
 
@@ -110,14 +114,14 @@ class SyclPlatform(Platform):
         unpacking = PsDeclaration(ctr, subscript)
         body.statements = [unpacking] + body.statements
 
-        return body
+        return body, GpuThreadsRange.from_ispace(ispace)
 
     def _item_type(self, rank: int):
         if not self._cfg.sycl_automatic_block_size:
             return PsCustomType(f"sycl::nd_item< {rank} >", const=True)
         else:
             return PsCustomType(f"sycl::item< {rank} >", const=True)
-        
+
     def _id_type(self, rank: int):
         return PsCustomType(f"sycl::id< {rank} >", const=True)
 
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 90fcf73d3..62f5d2968 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -2,35 +2,29 @@ from typing import cast
 
 from .enums import Target
 from .config import CreateKernelConfig
-from .backend import (
-    KernelFunction,
-    KernelParameter,
-    FieldShapeParam,
-    FieldStrideParam,
-    FieldPointerParam,
-)
-from .backend.symbols import PsSymbol
-from .backend.jit import JitBase
+from .backend import KernelFunction
 from .backend.ast.structural import PsBlock
-from .backend.arrays import PsArrayShapeSymbol, PsArrayStrideSymbol, PsArrayBasePointer
 from .backend.kernelcreation import (
     KernelCreationContext,
     KernelAnalysis,
     FreezeExpressions,
     Typifier,
 )
-from .backend.platforms import Platform
 from .backend.kernelcreation.iteration_space import (
     create_sparse_iteration_space,
     create_full_iteration_space,
 )
 
-from .backend.ast.analysis import collect_required_headers, collect_undefined_symbols
+
 from .backend.transformations import (
     EliminateConstants,
     EraseAnonymousStructTypes,
     SelectFunctions,
 )
+from .backend.kernelfunction import (
+    create_cpu_kernel_function,
+    create_gpu_kernel_function,
+)
 
 from .sympyextensions import AssignmentCollection, Assignment
 
@@ -91,15 +85,18 @@ def create_kernel(
             from .backend.platforms import GenericCpu
 
             platform = GenericCpu(ctx)
+            kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
         case Target.SYCL:
             from .backend.platforms import SyclPlatform
+
             platform = SyclPlatform(ctx, config.gpu_indexing)
+            kernel_ast, gpu_threads = platform.materialize_iteration_space(
+                kernel_body, ispace
+            )
         case _:
             #   TODO: CUDA/HIP platform
             raise NotImplementedError("Target platform not implemented")
 
-    kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
-
     #   Simplifying transformations
     elim_constants = EliminateConstants(ctx, extract_constant_exprs=True)
     kernel_ast = cast(PsBlock, elim_constants(kernel_ast))
@@ -108,6 +105,8 @@ def create_kernel(
     if config.target.is_cpu():
         from .backend.kernelcreation import optimize_cpu
 
+        assert isinstance(platform, GenericCpu)
+
         kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim)
 
     erase_anons = EraseAnonymousStructTypes(ctx)
@@ -117,42 +116,18 @@ def create_kernel(
     kernel_ast = cast(PsBlock, select_functions(kernel_ast))
 
     assert config.jit is not None
-    return create_kernel_function(
-        ctx, platform, kernel_ast, config.function_name, config.target, config.jit
-    )
 
-
-def create_kernel_function(
-    ctx: KernelCreationContext,
-    platform: Platform,
-    body: PsBlock,
-    function_name: str,
-    target_spec: Target,
-    jit: JitBase,
-):
-    undef_symbols = collect_undefined_symbols(body)
-
-    params = []
-    for symb in undef_symbols:
-        match symb:
-            case PsArrayShapeSymbol(name, _, arr, coord):
-                field = ctx.find_field(arr.name)
-                params.append(FieldShapeParam(name, symb.get_dtype(), field, coord))
-            case PsArrayStrideSymbol(name, _, arr, coord):
-                field = ctx.find_field(arr.name)
-                params.append(FieldStrideParam(name, symb.get_dtype(), field, coord))
-            case PsArrayBasePointer(name, _, arr):
-                field = ctx.find_field(arr.name)
-                params.append(FieldPointerParam(name, symb.get_dtype(), field))
-            case PsSymbol(name, _):
-                params.append(KernelParameter(name, symb.get_dtype()))
-
-    params.sort(key=lambda p: p.name)
-
-    req_headers = collect_required_headers(body)
-    req_headers |= platform.required_headers
-    req_headers |= ctx.required_headers
-
-    return KernelFunction(
-        body, target_spec, function_name, params, req_headers, ctx.constraints, jit
-    )
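+    #   GPU kernels additionally carry the thread range exported by the platform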
+    if config.target.is_cpu():
+        return create_cpu_kernel_function(
+            ctx, platform, kernel_ast, config.function_name, config.target, config.jit
+        )
+    else:
+        return create_gpu_kernel_function(
+            ctx,
+            platform,
+            kernel_ast,
+            gpu_threads,
+            config.function_name,
+            config.target,
+            config.jit,
+        )
diff --git a/tests/nbackend/kernelcreation/platform/test_basic_gpu.py b/tests/nbackend/kernelcreation/platform/test_basic_gpu.py
index e47f38e4d..1dc2cdb19 100644
--- a/tests/nbackend/kernelcreation/platform/test_basic_gpu.py
+++ b/tests/nbackend/kernelcreation/platform/test_basic_gpu.py
@@ -11,7 +11,7 @@ from pystencils.backend.ast.structural import PsBlock, PsLoop, PsComment
 from pystencils.backend.ast.expressions import PsExpression
 from pystencils.backend.ast import dfs_preorder
 
-from pystencils.backend.platforms import GenericGpu
+from pystencils.backend.platforms import CudaPlatform
 
 
 @pytest.mark.parametrize("layout", ["fzyx", "zyxf", "c", "f"])
@@ -19,7 +19,7 @@ def test_loop_nest(layout):
     ctx = KernelCreationContext()
 
     body = PsBlock([PsComment("Loop body goes here")])
-    platform = GenericGpu(ctx)
+    platform = CudaPlatform(ctx)
 
     #   FZYX Order
     archetype_field = Field.create_generic("fzyx_field", spatial_dimensions=3, layout=layout)
-- 
GitLab