diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index 1d716fa9676db4a2dda8f5241990461cab48d835..31d8ea192269a9a9947457814ff5e58d63f61c14 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
-from typing import Sequence, cast
+
+from abc import ABC, abstractmethod
+from typing import Iterable, Sequence, cast
 from types import NoneType
 
 from .astnode import PsAstNode, PsLeafMixIn
@@ -9,10 +11,35 @@ from ..memory import PsSymbol
 from .util import failing_cast
 
 
-class PsBlock(PsAstNode):
+class PsStructuralNode(PsAstNode, ABC):
+    """Base class for structural nodes in the pystencils AST.
+
+    This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from.
+    """
+
+    def clone(self):
+        """Clone this structural node by delegating to `_clone_structural`.
+
+        .. note::
+            Subclasses of `PsStructuralNode` should not override this method,
+            but implement `_clone_structural` instead.
+            That implementation shall call `clone` on any of its children.
+        """
+        return self._clone_structural()
+
+    @abstractmethod
+    def _clone_structural(self) -> PsStructuralNode:
+        """Implementation of structural node cloning; invoked by `clone`.
+
+        :meta public:
+        """
+        pass
+
+
+class PsBlock(PsStructuralNode):
     __match_args__ = ("statements",)
 
-    def __init__(self, cs: Sequence[PsAstNode]):
+    def __init__(self, cs: Iterable[PsStructuralNode]):
         self._statements = list(cs)
 
     @property
@@ -21,23 +48,23 @@ class PsBlock(PsAstNode):
 
     @children.setter
     def children(self, cs: Sequence[PsAstNode]):
-        self._statements = list(cs)
+        self._statements = list([failing_cast(PsStructuralNode, c) for c in cs])
 
     def get_children(self) -> tuple[PsAstNode, ...]:
         return tuple(self._statements)
 
     def set_child(self, idx: int, c: PsAstNode):
-        self._statements[idx] = c
+        self._statements[idx] = failing_cast(PsStructuralNode, c)
 
-    def clone(self) -> PsBlock:
-        return PsBlock([stmt.clone() for stmt in self._statements])
+    def _clone_structural(self) -> PsBlock:
+        return PsBlock([stmt._clone_structural() for stmt in self._statements])
 
     @property
-    def statements(self) -> list[PsAstNode]:
+    def statements(self) -> list[PsStructuralNode]:
         return self._statements
 
     @statements.setter
-    def statements(self, stm: Sequence[PsAstNode]):
+    def statements(self, stm: Sequence[PsStructuralNode]):
         self._statements = list(stm)
 
     def __repr__(self) -> str:
@@ -45,7 +72,7 @@ class PsBlock(PsAstNode):
         return f"PsBlock( {contents} )"
 
 
-class PsStatement(PsAstNode):
+class PsStatement(PsStructuralNode):
     __match_args__ = ("expression",)
 
     def __init__(self, expr: PsExpression):
@@ -59,7 +86,7 @@ class PsStatement(PsAstNode):
     def expression(self, expr: PsExpression):
         self._expression = expr
 
-    def clone(self) -> PsStatement:
+    def _clone_structural(self) -> PsStatement:
         return PsStatement(self._expression.clone())
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -71,7 +98,7 @@ class PsStatement(PsAstNode):
         self._expression = failing_cast(PsExpression, c)
 
 
-class PsAssignment(PsAstNode):
+class PsAssignment(PsStructuralNode):
     __match_args__ = (
         "lhs",
         "rhs",
@@ -101,7 +128,7 @@ class PsAssignment(PsAstNode):
     def rhs(self, expr: PsExpression):
         self._rhs = expr
 
-    def clone(self) -> PsAssignment:
+    def _clone_structural(self) -> PsAssignment:
         return PsAssignment(self._lhs.clone(), self._rhs.clone())
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -141,7 +168,7 @@ class PsDeclaration(PsAssignment):
     def declared_symbol(self) -> PsSymbol:
         return cast(PsSymbolExpr, self._lhs).symbol
 
-    def clone(self) -> PsDeclaration:
+    def _clone_structural(self) -> PsDeclaration:
         return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone())
 
     def set_child(self, idx: int, c: PsAstNode):
@@ -157,7 +184,7 @@ class PsDeclaration(PsAssignment):
         return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})"
 
 
-class PsLoop(PsAstNode):
+class PsLoop(PsStructuralNode):
     __match_args__ = ("counter", "start", "stop", "step", "body")
 
     def __init__(
@@ -214,13 +241,13 @@ class PsLoop(PsAstNode):
     def body(self, block: PsBlock):
         self._body = block
 
-    def clone(self) -> PsLoop:
+    def _clone_structural(self) -> PsLoop:
         return PsLoop(
             self._ctr.clone(),
             self._start.clone(),
             self._stop.clone(),
             self._step.clone(),
-            self._body.clone(),
+            self._body._clone_structural(),
         )
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -243,7 +270,7 @@ class PsLoop(PsAstNode):
                 assert False, "unreachable code"
 
 
-class PsConditional(PsAstNode):
+class PsConditional(PsStructuralNode):
     """Conditional branch"""
 
     __match_args__ = ("condition", "branch_true", "branch_false")
@@ -282,11 +309,11 @@ class PsConditional(PsAstNode):
     def branch_false(self, block: PsBlock | None):
         self._branch_false = block
 
-    def clone(self) -> PsConditional:
+    def _clone_structural(self) -> PsConditional:
         return PsConditional(
             self._condition.clone(),
-            self._branch_true.clone(),
-            self._branch_false.clone() if self._branch_false is not None else None,
+            self._branch_true._clone_structural(),
+            self._branch_false._clone_structural() if self._branch_false is not None else None,
         )
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -317,7 +344,7 @@ class PsEmptyLeafMixIn:
     pass
 
 
-class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
+class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
     """A C/C++ preprocessor pragma.
 
     Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``.
@@ -335,7 +362,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
     def text(self) -> str:
         return self._text
 
-    def clone(self) -> PsPragma:
+    def _clone_structural(self) -> PsPragma:
         return PsPragma(self.text)
 
     def structurally_equal(self, other: PsAstNode) -> bool:
@@ -345,7 +372,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
         return self._text == other._text
 
 
-class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
+class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
     __match_args__ = ("lines",)
 
     def __init__(self, text: str) -> None:
@@ -360,7 +387,7 @@ class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
     def lines(self) -> tuple[str, ...]:
         return self._lines
 
-    def clone(self) -> PsComment:
+    def _clone_structural(self) -> PsComment:
         return PsComment(self._text)
 
     def structurally_equal(self, other: PsAstNode) -> bool:
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index ce65cd85ded0c04fd7fb7858b88610989eb6a9d0..df6bfbd1f160ceec819d8d3f1923c43c145e88b3 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -28,6 +28,7 @@ from ..ast.structural import (
     PsDeclaration,
     PsExpression,
     PsSymbolExpr,
+    PsStructuralNode,
 )
 from ..ast.expressions import (
     PsBufferAcc,
@@ -109,7 +110,7 @@ class FreezeExpressions:
 
     def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode:
         if isinstance(obj, AssignmentCollection):
-            return PsBlock([self.visit(asm) for asm in obj.all_assignments])
+            return PsBlock([cast(PsStructuralNode, self.visit(asm)) for asm in obj.all_assignments])
         elif isinstance(obj, AssignmentBase):
             return cast(PsAssignment, self.visit(obj))
         elif isinstance(obj, _ExprLike):
diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index a9ec9d8d6fd6292eaccc58ce30bc24a64ac547ae..7aac0d412dbf5068c5e36a5740825dbd6e1eb6e5 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -197,7 +197,7 @@ class CudaPlatform(GenericGpu):
 
     @property
     def required_headers(self) -> set[str]:
-        return {'"gpu_defines.h"'}
+        return {'"pystencils_runtime/hip.h"'}  # TODO: move to HipPlatform once it is introduced
 
     def materialize_iteration_space(
         self, body: PsBlock, ispace: IterationSpace
diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py
index eae2b7598bfa43cf5379fe8782233be11d0dfef2..b613f3756d8708bcb844ee91c29346f388497010 100644
--- a/src/pystencils/backend/platforms/sycl.py
+++ b/src/pystencils/backend/platforms/sycl.py
@@ -25,12 +25,13 @@ from ..extensions.cpp import CppMethodCall
 
 from ..kernelcreation import KernelCreationContext, AstFactory
 from ..constants import PsConstant
-from .generic_gpu import GenericGpu
 from ..exceptions import MaterializationError
 from ...types import PsCustomType, PsIeeeFloatType, constify, PsIntegerType
 
+from .platform import Platform
 
-class SyclPlatform(GenericGpu):
+
+class SyclPlatform(Platform):
 
     def __init__(
         self,
diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py
index 10e6b39d844335eef3c0c02f2e5a9c5a7be4b94f..c9e8b3994cba3b1dd3c52e85acf405c00b4817f4 100644
--- a/src/pystencils/backend/transformations/add_pragmas.py
+++ b/src/pystencils/backend/transformations/add_pragmas.py
@@ -6,7 +6,7 @@ from collections import defaultdict
 
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsLoop, PsPragma
+from ..ast.structural import PsBlock, PsLoop, PsPragma, PsStructuralNode
 from ..ast.expressions import PsExpression
 
 from ...types import PsScalarType
@@ -56,13 +56,12 @@ class InsertPragmasAtLoops:
             self._insertions[ins.loop_nesting_depth].append(ins)
 
     def __call__(self, node: PsAstNode) -> PsAstNode:
-        is_loop = isinstance(node, PsLoop)
-        if is_loop:
+        if (is_loop := isinstance(node, PsLoop)):  # keep the flag: `node` is rebound to the wrapper below
             node = PsBlock([node])
 
         self.visit(node, Nesting(0))
 
-        if is_loop and len(node.children) == 1:
+        if is_loop and len(node.children) == 1:  # unwrap only a loop that was wrapped above
             node = node.children[0]
 
         return node
@@ -73,7 +72,7 @@ class InsertPragmasAtLoops:
                 return
 
             case PsBlock(children):
-                new_children: list[PsAstNode] = []
+                new_children: list[PsStructuralNode] = []
                 for c in children:
                     if isinstance(c, PsLoop):
                         nest.has_inner_loops = True
@@ -92,8 +91,8 @@ class InsertPragmasAtLoops:
                 node.children = new_children
 
             case other:
-                for c in other.children:
-                    self.visit(c, nest)
+                for child in other.children:
+                    self.visit(child, nest)
 
 
 class AddOpenMP:
diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py
index ab4401f9ca0142d9cfeec258eeb34fb2a7f6e8eb..c793c424d2417cbbdcc0cf3782e696c7c9226bb6 100644
--- a/src/pystencils/backend/transformations/ast_vectorizer.py
+++ b/src/pystencils/backend/transformations/ast_vectorizer.py
@@ -18,6 +18,7 @@ from ..ast.structural import (
     PsAssignment,
     PsLoop,
     PsEmptyLeafMixIn,
+    PsStructuralNode,
 )
 from ..ast.expressions import (
     PsExpression,
@@ -268,6 +269,18 @@ class AstVectorizer:
         """
         return self.visit(node, vc)
 
+    @overload
+    def visit(self, node: PsStructuralNode, vc: VectorizationContext) -> PsStructuralNode:
+        pass
+
+    @overload
+    def visit(self, node: PsExpression, vc: VectorizationContext) -> PsExpression:
+        pass
+
+    @overload
+    def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
+        pass
+
     def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
         """Vectorize a subtree."""
 
diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py
index f098d82df1ce6a748097756aa1616a72e57487b5..69dd1dd11d726e597c15ece772846ba8cd84acba 100644
--- a/src/pystencils/backend/transformations/eliminate_branches.py
+++ b/src/pystencils/backend/transformations/eliminate_branches.py
@@ -1,7 +1,9 @@
+from typing import cast
+
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
 from ..ast.analysis import collect_undefined_symbols
-from ..ast.structural import PsLoop, PsBlock, PsConditional
+from ..ast.structural import PsLoop, PsBlock, PsConditional, PsStructuralNode
 from ..ast.expressions import (
     PsAnd,
     PsCast,
@@ -71,9 +73,9 @@ class EliminateBranches:
                 ec.enclosing_loops.pop()
 
             case PsBlock(statements):
-                statements_new: list[PsAstNode] = []
+                statements_new: list[PsStructuralNode] = []
                 for stmt in statements:
-                    statements_new.append(self.visit(stmt, ec))
+                    statements_new.append(cast(PsStructuralNode, self.visit(stmt, ec)))
                 node.statements = statements_new
 
             case PsConditional():
diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index ab1cabc557a88b03766e6a9fb2ab44a84a5711da..3a07cb56fcb8f1c60107b5b1883c679191429e7e 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -6,7 +6,7 @@ import numpy as np
 from ..kernelcreation import KernelCreationContext, Typifier
 
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsDeclaration
+from ..ast.structural import PsBlock, PsDeclaration, PsStructuralNode
 from ..ast.expressions import (
     PsExpression,
     PsConstantExpr,
@@ -36,6 +36,7 @@ from ..ast.expressions import (
 )
 from ..ast.vector import PsVecBroadcast
 from ..ast.util import AstEqWrapper
+from ..exceptions import PsInternalCompilerError
 
 from ..constants import PsConstant
 from ..memory import PsSymbol
@@ -138,6 +139,11 @@ class EliminateConstants:
         node = self.visit(node, ecc)
 
         if ecc.extractions:
+            if not isinstance(node, PsStructuralNode):
+                raise PsInternalCompilerError(
+                    f"Cannot extract constant expressions from outermost node {node}"
+                )
+
             prepend_decls = [
                 PsDeclaration(PsExpression.make(symb), expr)
                 for symb, expr in ecc.extractions
diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
index f0e4cc9f19f1a046125bb3e8aab5302a9df2790c..f7fe81ad736981bee6f38427fbd4face73f0c455 100644
--- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
+++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
@@ -2,7 +2,7 @@ from typing import cast
 
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment
+from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment, PsStructuralNode
 from ..ast.expressions import (
     PsExpression,
     PsSymbolExpr,
@@ -99,7 +99,7 @@ class HoistLoopInvariantDeclarations:
                     return temp_block
 
             case PsBlock(statements):
-                statements_new: list[PsAstNode] = []
+                statements_new: list[PsStructuralNode] = []
                 for stmt in statements:
                     if isinstance(stmt, PsLoop):
                         loop = stmt
@@ -153,7 +153,7 @@ class HoistLoopInvariantDeclarations:
                 return
 
             case PsBlock(statements):
-                statements_new: list[PsAstNode] = []
+                statements_new: list[PsStructuralNode] = []
                 for stmt in statements:
                     if isinstance(stmt, PsLoop):
                         loop = stmt
@@ -178,7 +178,7 @@ class HoistLoopInvariantDeclarations:
         This method processes only statements of the given block, and any blocks directly nested inside it.
         It does not descend into control structures like conditionals and nested loops.
         """
-        statements_new: list[PsAstNode] = []
+        statements_new: list[PsStructuralNode] = []
 
         for node in block.statements:
             if isinstance(node, PsDeclaration):
diff --git a/src/pystencils/backend/transformations/rewrite.py b/src/pystencils/backend/transformations/rewrite.py
index 59241c295f42eeaf60f4cd03a5138214fdbd6c50..8dff9e45ec283fc6c3712c2e77ff56a9b2aaeae5 100644
--- a/src/pystencils/backend/transformations/rewrite.py
+++ b/src/pystencils/backend/transformations/rewrite.py
@@ -2,7 +2,7 @@ from typing import overload
 
 from ..memory import PsSymbol
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock
+from ..ast.structural import PsStructuralNode, PsBlock
 from ..ast.expressions import PsExpression, PsSymbolExpr
 
 
@@ -18,6 +18,13 @@ def substitute_symbols(
     pass
 
 
+@overload
+def substitute_symbols(
+    node: PsStructuralNode, subs: dict[PsSymbol, PsExpression]
+) -> PsStructuralNode:
+    pass
+
+
 @overload
 def substitute_symbols(
     node: PsAstNode, subs: dict[PsSymbol, PsExpression]
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index cc3411249d0fd95a68ca51f8761c3096bc09f2d2..1e88b29721b2e12740f48ab8db2b0ca81f4c634a 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import cast, Sequence, Iterable, Callable, TYPE_CHECKING
+from typing import cast, Sequence, Callable, TYPE_CHECKING
 from dataclasses import dataclass, replace
 
 from .target import Target
@@ -15,8 +15,7 @@ from .config import (
 from .kernel import Kernel, GpuKernel
 from .properties import PsSymbolProperty, FieldBasePtr
 from .parameters import Parameter
-from ..backend.functions import PsReductionFunction, ReductionFunctions
-from ..backend.ast.expressions import PsSymbolExpr, PsCall, PsMemAcc, PsConstantExpr
+from .functions import Lambda
 from .gpu_indexing import GpuIndexing, GpuLaunchConfiguration
 
 from ..field import Field
@@ -24,6 +23,8 @@ from ..types import PsIntegerType, PsScalarType
 
 from ..backend.memory import PsSymbol
 from ..backend.ast import PsAstNode
+from ..backend.functions import PsReductionFunction, ReductionFunctions
+from ..backend.ast.expressions import PsExpression, PsSymbolExpr, PsCall, PsMemAcc, PsConstantExpr
 from ..backend.ast.structural import PsBlock, PsLoop, PsDeclaration, PsAssignment
 from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers
 from ..backend.kernelcreation import (
@@ -220,20 +221,20 @@ class DefaultKernelCreationDriver:
         canonicalize = CanonicalizeSymbols(self._ctx, True)
         kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
 
-        if self._target.is_cpu():
-            return create_cpu_kernel_function(
-                self._ctx,
+        kernel_factory = KernelFactory(self._ctx)
+
+        if self._target.is_cpu() or self._target == Target.SYCL:
+            return kernel_factory.create_generic_kernel(
                 self._platform,
                 kernel_ast,
                 self._cfg.get_option("function_name"),
                 self._target,
                 self._cfg.get_jit(),
             )
-        else:
+        elif self._target.is_gpu():
             assert self._gpu_indexing is not None
 
-            return create_gpu_kernel_function(
-                self._ctx,
+            return kernel_factory.create_gpu_kernel(
                 self._platform,
                 kernel_ast,
                 self._cfg.get_option("function_name"),
@@ -241,6 +242,8 @@ class DefaultKernelCreationDriver:
                 self._cfg.get_jit(),
                 self._gpu_indexing.get_launch_config_factory(),
             )
+        else:
+            assert False, "unexpected target"
 
     def parse_kernel_body(
         self,
@@ -451,23 +454,11 @@ class DefaultKernelCreationDriver:
                     f"No platform is currently available for CPU target {self._target}"
                 )
 
-        elif Target._GPU in self._target:
+        elif self._target.is_gpu():
             gpu_opts = self._cfg.gpu
             omit_range_check: bool = gpu_opts.get_option("omit_range_check")
 
             match self._target:
-                case Target.SYCL:
-                    from ..backend.platforms import SyclPlatform
-
-                    auto_block_size: bool = self._cfg.sycl.get_option(
-                        "automatic_block_size"
-                    )
-
-                    return SyclPlatform(
-                        self._ctx,
-                        omit_range_check=omit_range_check,
-                        automatic_block_size=auto_block_size,
-                    )
                 case Target.CUDA:
                     from ..backend.platforms import CudaPlatform
 
@@ -482,89 +473,102 @@ class DefaultKernelCreationDriver:
                         omit_range_check=omit_range_check,
                         thread_mapping=thread_mapping,
                     )
+        elif self._target == Target.SYCL:
+            from ..backend.platforms import SyclPlatform
+
+            auto_block_size: bool = self._cfg.sycl.get_option("automatic_block_size")
+            omit_range_check = self._cfg.gpu.get_option("omit_range_check")
+
+            return SyclPlatform(
+                self._ctx,
+                omit_range_check=omit_range_check,
+                automatic_block_size=auto_block_size,
+            )
 
         raise NotImplementedError(
             f"Code generation for target {self._target} not implemented"
         )
 
 
-def create_cpu_kernel_function(
-    ctx: KernelCreationContext,
-    platform: Platform,
-    body: PsBlock,
-    function_name: str,
-    target_spec: Target,
-    jit: JitBase,
-) -> Kernel:
-    undef_symbols = collect_undefined_symbols(body)
-
-    params = _get_function_params(ctx, undef_symbols)
-    req_headers = _get_headers(ctx, platform, body)
-
-    kfunc = Kernel(body, target_spec, function_name, params, req_headers, jit)
-    kfunc.metadata.update(ctx.metadata)
-    return kfunc
-
-
-def create_gpu_kernel_function(
-    ctx: KernelCreationContext,
-    platform: Platform,
-    body: PsBlock,
-    function_name: str,
-    target_spec: Target,
-    jit: JitBase,
-    launch_config_factory: Callable[[], GpuLaunchConfiguration],
-) -> GpuKernel:
-    undef_symbols = collect_undefined_symbols(body)
-
-    params = _get_function_params(ctx, undef_symbols)
-    req_headers = _get_headers(ctx, platform, body)
-
-    kfunc = GpuKernel(
-        body,
-        target_spec,
-        function_name,
-        params,
-        req_headers,
-        jit,
-        launch_config_factory,
-    )
-    kfunc.metadata.update(ctx.metadata)
-    return kfunc
-
-
-def _symbol_to_param(ctx: KernelCreationContext, symbol: PsSymbol):
-    from pystencils.backend.memory import BufferBasePtr, BackendPrivateProperty
-
-    props: set[PsSymbolProperty] = set()
-    for prop in symbol.properties:
-        match prop:
-            case BufferBasePtr(buf):
-                field = ctx.find_field(buf.name)
-                props.add(FieldBasePtr(field))
-            case BackendPrivateProperty():
-                pass
-            case _:
-                props.add(prop)
-
-    return Parameter(symbol.name, symbol.get_dtype(), props)
-
-
-def _get_function_params(
-    ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
-) -> list[Parameter]:
-    params: list[Parameter] = [_symbol_to_param(ctx, s) for s in symbols]
-    params.sort(key=lambda p: p.name)
-    return params
-
-
-def _get_headers(
-    ctx: KernelCreationContext, platform: Platform, body: PsBlock
-) -> set[str]:
-    req_headers = collect_required_headers(body)
-    req_headers |= platform.required_headers
-    req_headers |= ctx.required_headers
-    return req_headers
+class KernelFactory:
+    """Factory for wrapping up backend and IR objects into exportable kernels and function objects."""
+
+    def __init__(self, ctx: KernelCreationContext):
+        self._ctx = ctx
+
+    def create_lambda(self, expr: PsExpression) -> Lambda:
+        """Create a `Lambda` from an expression, using its undefined symbols as parameters."""
+        params = self._get_function_params(expr)
+        return Lambda(expr, params)
+
+    def create_generic_kernel(
+        self,
+        platform: Platform,
+        body: PsBlock,
+        function_name: str,
+        target_spec: Target,
+        jit: JitBase,
+    ) -> Kernel:
+        """Create a kernel for a generic target; its parameters are the undefined symbols of `body`."""
+        params = self._get_function_params(body)
+        req_headers = self._get_headers(platform, body)
+
+        kfunc = Kernel(body, target_spec, function_name, params, req_headers, jit)
+        kfunc.metadata.update(self._ctx.metadata)
+        return kfunc
+
+    def create_gpu_kernel(
+        self,
+        platform: Platform,
+        body: PsBlock,
+        function_name: str,
+        target_spec: Target,
+        jit: JitBase,
+        launch_config_factory: Callable[[], GpuLaunchConfiguration],
+    ) -> GpuKernel:
+        """Create a kernel for a GPU target; `launch_config_factory` is passed on to the `GpuKernel`."""
+        params = self._get_function_params(body)
+        req_headers = self._get_headers(platform, body)
+
+        kfunc = GpuKernel(
+            body,
+            target_spec,
+            function_name,
+            params,
+            req_headers,
+            jit,
+            launch_config_factory,
+        )
+        kfunc.metadata.update(self._ctx.metadata)
+        return kfunc
+
+    def _symbol_to_param(self, symbol: PsSymbol) -> Parameter:
+        from pystencils.backend.memory import BufferBasePtr, BackendPrivateProperty
+
+        props: set[PsSymbolProperty] = set()
+        for prop in symbol.properties:
+            match prop:
+                case BufferBasePtr(buf):  # replaced by a FieldBasePtr for the field found under the buffer's name
+                    field = self._ctx.find_field(buf.name)
+                    props.add(FieldBasePtr(field))
+                case BackendPrivateProperty():  # not copied to the parameter's properties
+                    pass
+                case _:
+                    props.add(prop)
+
+        return Parameter(symbol.name, symbol.get_dtype(), props)
+
+    def _get_function_params(self, ast: PsAstNode) -> list[Parameter]:
+        symbols = collect_undefined_symbols(ast)  # every undefined symbol of the AST becomes a parameter
+        params: list[Parameter] = [self._symbol_to_param(s) for s in symbols]
+        params.sort(key=lambda p: p.name)  # parameters are ordered by name
+        return params
+
+    def _get_headers(self, platform: Platform, body: PsBlock) -> set[str]:
+        req_headers = collect_required_headers(body)  # merged with platform and context headers below
+        req_headers |= platform.required_headers
+        req_headers |= self._ctx.required_headers
+        return req_headers
 
 
 @dataclass
diff --git a/src/pystencils/codegen/functions.py b/src/pystencils/codegen/functions.py
index f6be3b1f3446c6b9a25a0013f0e06d099edf5bed..c24dbaffb9947d68c854f83532e87386914c6677 100644
--- a/src/pystencils/codegen/functions.py
+++ b/src/pystencils/codegen/functions.py
@@ -4,21 +4,12 @@ from typing import Sequence, Any
 from .parameters import Parameter
 from ..types import PsType
 
-from ..backend.kernelcreation import KernelCreationContext
 from ..backend.ast.expressions import PsExpression
 
 
 class Lambda:
     """A one-line function emitted by the code generator as an auxiliary object."""
 
-    @staticmethod
-    def from_expression(ctx: KernelCreationContext, expr: PsExpression):
-        from ..backend.ast.analysis import collect_undefined_symbols
-        from .driver import _get_function_params
-
-        params = _get_function_params(ctx, collect_undefined_symbols(expr))
-        return Lambda(expr, params)
-
     def __init__(self, expr: PsExpression, params: Sequence[Parameter]):
         self._expr = expr
         self._params = tuple(params)
diff --git a/src/pystencils/codegen/gpu_indexing.py b/src/pystencils/codegen/gpu_indexing.py
index 2d22ec624856d9cf8a0b825845fee04caaa4ee74..27d6fc817d5a9193c3faa4b170d907987fe6022e 100644
--- a/src/pystencils/codegen/gpu_indexing.py
+++ b/src/pystencils/codegen/gpu_indexing.py
@@ -228,8 +228,10 @@ class GpuIndexing:
         self._manual_launch_grid = manual_launch_grid
 
         from ..backend.kernelcreation import AstFactory
+        from .driver import KernelFactory
 
-        self._factory = AstFactory(self._ctx)
+        self._ast_factory = AstFactory(self._ctx)
+        self._kernel_factory = KernelFactory(self._ctx)
 
     def get_thread_mapping(self) -> ThreadMapping:
         """Retrieve a thread mapping object for use by the backend"""
@@ -265,9 +267,14 @@ class GpuIndexing:
                 f" for a {rank}-dimensional kernel."
             )
 
+        work_items_expr += tuple(
+            self._ast_factory.parse_index(1)
+            for _ in range(3 - rank)
+        )
+        
         num_work_items = cast(
             _Dim3Lambda,
-            tuple(Lambda.from_expression(self._ctx, wit) for wit in work_items_expr),
+            tuple(self._kernel_factory.create_lambda(wit) for wit in work_items_expr),
         )
 
         def factory():
@@ -305,15 +312,15 @@ class GpuIndexing:
             raise ValueError(f"Iteration space rank is too large: {rank}")
 
         block_size = (
-            Lambda.from_expression(self._ctx, work_items[0]),
-            Lambda.from_expression(self._ctx, self._factory.parse_index(1)),
-            Lambda.from_expression(self._ctx, self._factory.parse_index(1)),
+            self._kernel_factory.create_lambda(work_items[0]),
+            self._kernel_factory.create_lambda(self._ast_factory.parse_index(1)),
+            self._kernel_factory.create_lambda(self._ast_factory.parse_index(1)),
         )
 
         grid_size = tuple(
-            Lambda.from_expression(self._ctx, wit) for wit in work_items[1:]
+            self._kernel_factory.create_lambda(wit) for wit in work_items[1:]
         ) + tuple(
-            Lambda.from_expression(self._ctx, self._factory.parse_index(1))
+            self._kernel_factory.create_lambda(self._ast_factory.parse_index(1))
             for _ in range(4 - rank)
         )
 
@@ -350,7 +357,7 @@ class GpuIndexing:
                 return tuple(ispace.actual_iterations(dim) for dim in dimensions)
 
             case SparseIterationSpace():
-                return (self._factory.parse_index(ispace.index_list.shape[0]),)
+                return (self._ast_factory.parse_index(ispace.index_list.shape[0]),)
 
             case _:
                 assert False, "unexpected iteration space"
diff --git a/src/pystencils/codegen/target.py b/src/pystencils/codegen/target.py
index b847a8139a8725c9c926b7c12c9556aba3ec6e87..0d724b87730f0ec327772bccbb55a8bfff7c8ddd 100644
--- a/src/pystencils/codegen/target.py
+++ b/src/pystencils/codegen/target.py
@@ -89,10 +89,13 @@ class Target(Flag):
     GPU = CUDA
     """Alias for `Target.CUDA`, for backward compatibility."""
 
-    SYCL = _GPU | _SYCL
+    SYCL = _SYCL
     """SYCL kernel target.
     
     Generate a function to be called within a SYCL parallel command.
+
+    ..  note::
+        The SYCL target is experimental and not thoroughly tested yet.
     """
 
     def is_automatic(self) -> bool:
diff --git a/src/pystencils/include/PyStencilsField.h b/src/pystencils/include/PyStencilsField.h
deleted file mode 100644
index 3055cae2365279e28fdcaab4353779b97ca27d35..0000000000000000000000000000000000000000
--- a/src/pystencils/include/PyStencilsField.h
+++ /dev/null
@@ -1,19 +0,0 @@
-#pragma once
-
-extern "C++" {
-#ifdef __CUDA_ARCH__
-template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField {
-  DTYPE_T *data;
-  DTYPE_T shape[DIMENSION];
-  DTYPE_T stride[DIMENSION];
-};
-#else
-#include <array>
-
-template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField {
-  DTYPE_T *data;
-  std::array<DTYPE_T, DIMENSION> shape;
-  std::array<DTYPE_T, DIMENSION> stride;
-};
-#endif
-}
diff --git a/src/pystencils/include/half_precision.h b/src/pystencils/include/pystencils_runtime/half.h
similarity index 100%
rename from src/pystencils/include/half_precision.h
rename to src/pystencils/include/pystencils_runtime/half.h
diff --git a/src/pystencils/include/gpu_defines.h b/src/pystencils/include/pystencils_runtime/hip.h
similarity index 95%
rename from src/pystencils/include/gpu_defines.h
rename to src/pystencils/include/pystencils_runtime/hip.h
index 34cff79dea2f14399622a0026e362a4832bd739c..b0b4d967911688bc302537110e54ea0901e661b8 100644
--- a/src/pystencils/include/gpu_defines.h
+++ b/src/pystencils/include/pystencils_runtime/hip.h
@@ -1,11 +1,5 @@
 #pragma once
 
-#define POS_INFINITY __int_as_float(0x7f800000)
-#define NEG_INFINITY __int_as_float(0xff800000)
-#ifndef INFINITY
-#define INFINITY POS_INFINITY
-#endif
-
 #ifdef __HIPCC_RTC__
 typedef __hip_uint8_t uint8_t;
 typedef __hip_int8_t int8_t;
diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py
index 4ea991e28cbadc68c81129bcff1a7c02d689bf07..4c3c8945e34b55cd557c180397d72086f766bf7e 100644
--- a/src/pystencils/jit/gpu_cupy.py
+++ b/src/pystencils/jit/gpu_cupy.py
@@ -252,8 +252,8 @@ class CupyJit(JitBase):
         headers = self._runtime_headers
         headers |= kfunc.required_headers
 
-        if '"half_precision.h"' in headers:
-            headers.remove('"half_precision.h"')
+        if '"pystencils_runtime/half.h"' in headers:
+            headers.remove('"pystencils_runtime/half.h"')
             if cp.cuda.runtime.is_hip:
                 headers.add("<hip/hip_fp16.h>")
             else:
diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py
index 825ac1d5d35fde0f26c5a9ebadb55ec43004c9ae..8dea97ca43539b966e260ac7a206b0e26b3b2110 100644
--- a/src/pystencils/types/types.py
+++ b/src/pystencils/types/types.py
@@ -661,7 +661,7 @@ class PsIeeeFloatType(PsScalarType):
     @property
     def required_headers(self) -> set[str]:
         if self._width == 16:
-            return {'"half_precision.h"'}
+            return {'"pystencils_runtime/half.h"'}
         else:
             return set()
 
@@ -672,7 +672,7 @@ class PsIeeeFloatType(PsScalarType):
 
         match self.width:
             case 16:
-                return f"((half) {value})"  # see include/half_precision.h
+                return f"((half) {value})"  # see include/pystencils_runtime/half.h
             case 32:
                 return f"{value}f"
             case 64:
diff --git a/tests/kernelcreation/test_gpu.py b/tests/kernelcreation/test_gpu.py
index 10b37e610cebd23c9fc961f14118aee5f24582c4..f1905b1fcb7c7406f43cfb94af2928b6f35bc3f8 100644
--- a/tests/kernelcreation/test_gpu.py
+++ b/tests/kernelcreation/test_gpu.py
@@ -31,7 +31,7 @@ except ImportError:
 @pytest.mark.parametrize("indexing_scheme", ["linear3d", "blockwise4d"])
 @pytest.mark.parametrize("omit_range_check", [False, True])
 @pytest.mark.parametrize("manual_grid", [False, True])
-def test_indexing_options(
+def test_indexing_options_3d(
     indexing_scheme: str, omit_range_check: bool, manual_grid: bool
 ):
     src, dst = fields("src, dst: [3D]")
@@ -76,6 +76,52 @@ def test_indexing_options(
     cp.testing.assert_allclose(dst_arr, expected)
 
 
+@pytest.mark.parametrize("indexing_scheme", ["linear3d", "blockwise4d"])
+@pytest.mark.parametrize("omit_range_check", [False, True])
+@pytest.mark.parametrize("manual_grid", [False, True])
+def test_indexing_options_2d(
+    indexing_scheme: str, omit_range_check: bool, manual_grid: bool
+):
+    src, dst = fields("src, dst: [2D]")
+    asm = Assignment(
+        dst.center(),
+        src[-1, 0]
+        + src[1, 0]
+        + src[0, -1]
+        + src[0, 1]
+    )
+
+    cfg = CreateKernelConfig(target=Target.CUDA)
+    cfg.gpu.indexing_scheme = indexing_scheme
+    cfg.gpu.omit_range_check = omit_range_check
+    cfg.gpu.manual_launch_grid = manual_grid
+
+    ast = create_kernel(asm, cfg)
+    kernel = ast.compile()
+
+    src_arr = cp.ones((18, 42))
+    dst_arr = cp.zeros_like(src_arr)
+
+    if manual_grid:
+        match indexing_scheme:
+            case "linear3d":
+                kernel.launch_config.block_size = (10, 8, 1)
+                kernel.launch_config.grid_size = (4, 2, 1)
+            case "blockwise4d":
+                kernel.launch_config.block_size = (40, 1, 1)
+                kernel.launch_config.grid_size = (16, 1, 1)
+
+    elif indexing_scheme == "linear3d":
+        kernel.launch_config.block_size = (10, 8, 1)
+
+    kernel(src=src_arr, dst=dst_arr)
+
+    expected = cp.zeros_like(src_arr)
+    expected[1:-1, 1:-1].fill(4.0)
+
+    cp.testing.assert_allclose(dst_arr, expected)
+
+
 def test_invalid_indexing_schemes():
     src, dst = fields("src, dst: [4D]")
     asm = Assignment(src.center(0), dst.center(0))
diff --git a/tests/kernelcreation/test_sycl_codegen.py b/tests/kernelcreation/test_sycl_codegen.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6907c9965b4b80a5b97865063170dfdc3654615
--- /dev/null
+++ b/tests/kernelcreation/test_sycl_codegen.py
@@ -0,0 +1,45 @@
+"""
+Since we don't have a JIT compiler for SYCL, these tests can only
+perform dry-run testing: they inspect the generated code without executing it.
+If the SYCL target should ever become non-experimental, we need to
+find a way to properly test SYCL kernels in execution.
+
+These tests primarily check that the code generation driver runs
+successfully for the SYCL target.
+"""
+
+import sympy as sp
+from pystencils import (
+    create_kernel,
+    Target,
+    fields,
+    Assignment,
+    CreateKernelConfig,
+)
+
+
+def test_sycl_kernel_static():
+    src, dst = fields("src, dst: [2D]")
+    asm = Assignment(dst.center(), sp.sin(src.center()) + sp.cos(src.center()))
+
+    cfg = CreateKernelConfig(target=Target.SYCL)
+    kernel = create_kernel(asm, cfg)
+
+    code_string = kernel.get_c_code()
+
+    assert "sycl::id< 2 >" in code_string
+    assert "sycl::sin(" in code_string
+    assert "sycl::cos(" in code_string
+
+
+def test_sycl_kernel_manual_block_size():
+    src, dst = fields("src, dst: [2D]")
+    asm = Assignment(dst.center(), sp.sin(src.center()) + sp.cos(src.center()))
+
+    cfg = CreateKernelConfig(target=Target.SYCL)
+    cfg.sycl.automatic_block_size = False
+    kernel = create_kernel(asm, cfg)
+
+    code_string = kernel.get_c_code()
+
+    assert "sycl::nd_item< 2 >" in code_string
diff --git a/tests/nbackend/test_vectorization.py b/tests/nbackend/test_vectorization.py
index b60dc24774566d67eaa271c6ab775374746d89cf..fecade65d97afcaae4382bcc2ced119b2a957bed 100644
--- a/tests/nbackend/test_vectorization.py
+++ b/tests/nbackend/test_vectorization.py
@@ -20,7 +20,7 @@ from pystencils.backend.transformations import (
     LowerToC,
 )
 from pystencils.backend.constants import PsConstant
-from pystencils.codegen.driver import create_cpu_kernel_function
+from pystencils.codegen.driver import KernelFactory
 from pystencils.jit import LegacyCpuJit
 from pystencils import Target, fields, Assignment, Field
 from pystencils.field import create_numpy_array_with_layout
@@ -135,8 +135,8 @@ def create_vector_kernel(
     lower = LowerToC(ctx)
     loop_nest = lower(loop_nest)
 
-    func = create_cpu_kernel_function(
-        ctx,
+    kfactory = KernelFactory(ctx)
+    func = kfactory.create_generic_kernel(
         platform,
         PsBlock([loop_nest]),
         "vector_kernel",