diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py
index 4374ccda4ef8c508a909975f254cbf32936912ae..028e4b88532cf1e3a803c056983c5d1f29cfd865 100644
--- a/src/pystencils/__init__.py
+++ b/src/pystencils/__init__.py
@@ -19,7 +19,7 @@ from .cache import clear_cache
 from .kernel_decorator import kernel, kernel_config
 from .kernelcreation import create_kernel, create_staggered_kernel
 from .codegen import Kernel
-from .backend.jit import no_jit
+from .jit import no_jit
 from .backend.exceptions import KernelConstraintsError
 from .slicing import make_slice
 from .spatial_coordinates import (
diff --git a/src/pystencils/backend/constraints.py b/src/pystencils/backend/constraints.py
deleted file mode 100644
index 229f6718c65e5e4941e33aa09b5363f5962abae5..0000000000000000000000000000000000000000
--- a/src/pystencils/backend/constraints.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from __future__ import annotations
-
-from typing import Any, TYPE_CHECKING
-from dataclasses import dataclass
-
-if TYPE_CHECKING:
-    from .kernelfunction import KernelParameter
-
-
-@dataclass
-class KernelParamsConstraint:
-    condition: Any  # FIXME Implement conditions
-    message: str = ""
-
-    def to_code(self):
-        raise NotImplementedError()
-
-    def get_parameters(self) -> set[KernelParameter]:
-        raise NotImplementedError()
-
-    def __str__(self) -> str:
-        return f"{self.message} [{self.condition}]"
diff --git a/src/pystencils/backend/emission/base_printer.py b/src/pystencils/backend/emission/base_printer.py
index d721b9f895c79ebdd6a58858cc0408613fefa6e4..a4358bbf328b65aaf5e45eff5a2083ef067285a6 100644
--- a/src/pystencils/backend/emission/base_printer.py
+++ b/src/pystencils/backend/emission/base_printer.py
@@ -61,7 +61,7 @@ from ..constants import PsConstant
 from ...types import PsType
 
 if TYPE_CHECKING:
-    from ...codegen import Kernel, GpuKernel
+    from ...codegen import Kernel
 
 
 class EmissionError(Exception):
@@ -175,6 +175,7 @@ class BasePrinter(ABC):
         self._indent_width = indent_width
 
     def __call__(self, obj: PsAstNode | Kernel) -> str:
+        from ...codegen import Kernel
         if isinstance(obj, Kernel):
             sig = self.print_signature(obj)
             body_code = self.visit(obj.body, PrinterCtx())
@@ -383,6 +384,8 @@ class BasePrinter(ABC):
         return signature
 
     def _func_prefix(self, func: Kernel):
+        from ...codegen import GpuKernel
+
         if isinstance(func, GpuKernel) and func.target == Target.CUDA:
             return "__global__"
         else:
diff --git a/src/pystencils/backend/emission/c_printer.py b/src/pystencils/backend/emission/c_printer.py
index 95e27bd66732f04e3d20767cf8c6d35d0cfd2450..90a7e54e22b3eb14866c9260c85247baf8b4f340 100644
--- a/src/pystencils/backend/emission/c_printer.py
+++ b/src/pystencils/backend/emission/c_printer.py
@@ -1,18 +1,23 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
 from pystencils.backend.ast.astnode import PsAstNode
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.emission.base_printer import PrinterCtx, EmissionError
 from pystencils.backend.memory import PsSymbol
 from .base_printer import BasePrinter
 
-from ..kernelfunction import KernelFunction
 from ...types import PsType, PsArrayType, PsScalarType, PsTypeError
 from ..ast.expressions import PsBufferAcc
 from ..ast.vector import PsVecMemAcc
 
+if TYPE_CHECKING:
+    from ...codegen import Kernel
+
 
-def emit_code(kernel: KernelFunction):
+def emit_code(ast: PsAstNode | Kernel):
     printer = CAstPrinter()
-    return printer(kernel)
+    return printer(ast)
 
 
 class CAstPrinter(BasePrinter):
diff --git a/src/pystencils/backend/emission/ir_printer.py b/src/pystencils/backend/emission/ir_printer.py
index 124ce200d3aab9e3b111dd0481bd8bc7faad817f..ffb65181ccd71ff95dffd6d006617dadc6809eea 100644
--- a/src/pystencils/backend/emission/ir_printer.py
+++ b/src/pystencils/backend/emission/ir_printer.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.emission.base_printer import PrinterCtx
 from pystencils.backend.memory import PsSymbol
@@ -9,8 +12,11 @@ from ..ast import PsAstNode
 from ..ast.expressions import PsBufferAcc
 from ..ast.vector import PsVecMemAcc, PsVecBroadcast
 
+if TYPE_CHECKING:
+    from ...codegen import Kernel
+
 
-def emit_ir(ir: PsAstNode):
+def emit_ir(ir: PsAstNode | Kernel):
     """Emit the IR as C-like pseudo-code for inspection."""
     ir_printer = IRAstPrinter()
     return ir_printer(ir)
diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index bb7bd708d0d1b3315603b2b152a95cfbed98b28b..39fb8ef6dac855553b7e18d2a688c67ca45fb227 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -18,7 +18,6 @@ from ...types import (
     PsPointerType,
     deconstify,
 )
-from ..constraints import KernelParamsConstraint
 from ..exceptions import PsInternalCompilerError, KernelConstraintsError
 
 from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace
@@ -81,7 +80,6 @@ class KernelCreationContext:
 
         self._ispace: IterationSpace | None = None
 
-        self._constraints: list[KernelParamsConstraint] = []
         self._req_headers: set[str] = set()
 
         self._metadata: dict[str, Any] = dict()
@@ -96,15 +94,6 @@ class KernelCreationContext:
         """Data type used by default for index expressions"""
         return self._index_dtype
 
-    #   Constraints
-
-    def add_constraints(self, *constraints: KernelParamsConstraint):
-        self._constraints += constraints
-
-    @property
-    def constraints(self) -> tuple[KernelParamsConstraint, ...]:
-        return tuple(self._constraints)
-
     @property
     def metadata(self) -> dict[str, Any]:
         return self._metadata
diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index a7802c931a5c4e9da3d008b48caa10a954130ea9..031a0d843f3f5a648f2cd8c390134ba308c1c833 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -457,7 +457,7 @@ def create_full_iteration_space(
     # Otherwise, if an iteration slice was specified, use that
     # Otherwise, use the inferred ghost layers
 
-    from ...codegen.config import AUTO
+    from ...codegen.config import AUTO, _AUTO_TYPE
 
     if ghost_layers is AUTO:
         if len(domain_field_accesses) > 0:
diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py
deleted file mode 100644
index 3c7e103b34da067c82bfb85493fb5eb8059f6ff2..0000000000000000000000000000000000000000
--- a/src/pystencils/backend/kernelfunction.py
+++ /dev/null
@@ -1,113 +0,0 @@
-from __future__ import annotations
-
-from warnings import warn
-from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING
-from itertools import chain
-
-from .._deprecation import _deprecated
-
-from .ast.structural import PsBlock
-from .ast.analysis import collect_required_headers, collect_undefined_symbols
-from .memory import PsSymbol
-from ..codegen.properties import (
-    PsSymbolProperty,
-    _FieldProperty,
-    FieldShape,
-    FieldStride,
-    FieldBasePtr,
-)
-from .kernelcreation.context import KernelCreationContext
-from .platforms import Platform, GpuThreadsRange
-
-from .constraints import KernelParamsConstraint
-from ..types import PsType
-
-from ..codegen.target import Target
-from ..field import Field
-from ..sympyextensions import TypedSymbol
-
-if TYPE_CHECKING:
-    from .jit import JitBase
-
-
-
-
-def create_cpu_kernel_function(
-    ctx: KernelCreationContext,
-    platform: Platform,
-    body: PsBlock,
-    function_name: str,
-    target_spec: Target,
-    jit: JitBase,
-):
-    undef_symbols = collect_undefined_symbols(body)
-
-    params = _get_function_params(ctx, undef_symbols)
-    req_headers = _get_headers(ctx, platform, body)
-
-    kfunc = KernelFunction(
-        body, target_spec, function_name, params, req_headers, ctx.constraints, jit
-    )
-    kfunc.metadata.update(ctx.metadata)
-    return kfunc
-
-
-
-
-def create_gpu_kernel_function(
-    ctx: KernelCreationContext,
-    platform: Platform,
-    body: PsBlock,
-    threads_range: GpuThreadsRange | None,
-    function_name: str,
-    target_spec: Target,
-    jit: JitBase,
-):
-    undef_symbols = collect_undefined_symbols(body)
-
-    if threads_range is not None:
-        for threads in threads_range.num_work_items:
-            undef_symbols |= collect_undefined_symbols(threads)
-
-    params = _get_function_params(ctx, undef_symbols)
-    req_headers = _get_headers(ctx, platform, body)
-
-    kfunc = GpuKernelFunction(
-        body,
-        threads_range,
-        target_spec,
-        function_name,
-        params,
-        req_headers,
-        ctx.constraints,
-        jit,
-    )
-    kfunc.metadata.update(ctx.metadata)
-    return kfunc
-
-
-def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]):
-    params: list[KernelParameter] = []
-
-    from pystencils.backend.memory import BufferBasePtr
-
-    for symb in symbols:
-        props: set[PsSymbolProperty] = set()
-        for prop in symb.properties:
-            match prop:
-                case FieldShape() | FieldStride():
-                    props.add(prop)
-                case BufferBasePtr(buf):
-                    field = ctx.find_field(buf.name)
-                    props.add(FieldBasePtr(field))
-        params.append(KernelParameter(symb.name, symb.get_dtype(), props))
-
-    params.sort(key=lambda p: p.name)
-    return params
-
-
-def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock):
-    req_headers = collect_required_headers(body)
-    req_headers |= platform.required_headers
-    req_headers |= ctx.required_headers
-    return req_headers
diff --git a/src/pystencils/backend/platforms/__init__.py b/src/pystencils/backend/platforms/__init__.py
index 9332453c6c1b60255f1869f011bfa661ee670ea0..589841db87efb598ffeed20d4d11db7ffcd452cc 100644
--- a/src/pystencils/backend/platforms/__init__.py
+++ b/src/pystencils/backend/platforms/__init__.py
@@ -1,6 +1,6 @@
 from .platform import Platform
 from .generic_cpu import GenericCpu, GenericVectorCpu
-from .generic_gpu import GenericGpu, GpuThreadsRange
+from .generic_gpu import GenericGpu
 from .cuda import CudaPlatform
 from .x86 import X86VectorCpu, X86VectorArch
 from .sycl import SyclPlatform
@@ -12,7 +12,6 @@ __all__ = [
     "X86VectorCpu",
     "X86VectorArch",
     "GenericGpu",
-    "GpuThreadsRange",
     "CudaPlatform",
     "SyclPlatform",
 ]
diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index 048bcb0d55743eb46344a5e255c778cba7f40854..b8f5356ae7bbdb961773e122e2d6d377f0f7e45f 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -1,9 +1,10 @@
+from __future__ import annotations
 from warnings import warn
 from typing import TYPE_CHECKING
 
 from ...types import constify
 from ..exceptions import MaterializationError
-from .generic_gpu import GenericGpu, GpuThreadsRange
+from .generic_gpu import GenericGpu
 
 from ..kernelcreation import (
     Typifier,
@@ -29,7 +30,7 @@ from ..literals import PsLiteral
 from ..functions import PsMathFunction, MathFunctions, CFunction
 
 if TYPE_CHECKING:
-    from ...codegen.config import GpuIndexingConfig
+    from ...codegen import GpuIndexingConfig, GpuThreadsRange
 
 int32 = PsSignedIntegerType(width=32, const=False)
 
@@ -54,6 +55,9 @@ class CudaPlatform(GenericGpu):
         self, ctx: KernelCreationContext, indexing_cfg: GpuIndexingConfig | None = None
     ) -> None:
         super().__init__(ctx)
+
+        from ...codegen.config import GpuIndexingConfig
+
         self._cfg = indexing_cfg if indexing_cfg is not None else GpuIndexingConfig()
         self._typify = Typifier(ctx)
 
@@ -134,7 +138,7 @@ class CudaPlatform(GenericGpu):
 
         if not self._cfg.manual_launch_grid:
             try:
-                threads_range = GpuThreadsRange.from_ispace(ispace)
+                threads_range = self.threads_from_ispace(ispace)
             except MaterializationError as e:
                 warn(
                     str(e.args[0])
@@ -212,7 +216,7 @@ class CudaPlatform(GenericGpu):
             body.statements = [sparse_idx_decl] + body.statements
             ast = body
 
-        return ast, GpuThreadsRange.from_ispace(ispace)
+        return ast, self.threads_from_ispace(ispace)
 
     def _linear_thread_idx(self, coord: int):
         block_size = BLOCK_DIM[coord]
diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index 0512351cda1ca22a40a9afa65aabb66dea7bde0a..da8fa64f9f6d9e8de08b84bc169ae2442d0dec42 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Sequence
+from typing import TYPE_CHECKING
 from abc import abstractmethod
 
 from ..ast.expressions import PsExpression
@@ -12,55 +12,33 @@ from ..kernelcreation.iteration_space import (
 from .platform import Platform
 from ..exceptions import MaterializationError
 
+if TYPE_CHECKING:
+    from ...codegen.kernel import GpuThreadsRange
 
-class GpuThreadsRange:
-    """Number of threads required by a GPU kernel, in order (x, y, z)."""
 
-    @staticmethod
-    def from_ispace(ispace: IterationSpace) -> GpuThreadsRange:
+class GenericGpu(Platform):
+    @abstractmethod
+    def materialize_iteration_space(
+        self, block: PsBlock, ispace: IterationSpace
+    ) -> tuple[PsBlock, GpuThreadsRange | None]:
+        pass
+
+    @classmethod
+    def threads_from_ispace(cls, ispace: IterationSpace) -> GpuThreadsRange:
+        from ...codegen.kernel import GpuThreadsRange
+
         if isinstance(ispace, FullIterationSpace):
-            return GpuThreadsRange._from_full_ispace(ispace)
+            return cls._threads_from_full_ispace(ispace)
         elif isinstance(ispace, SparseIterationSpace):
             work_items = (PsExpression.make(ispace.index_list.shape[0]),)
             return GpuThreadsRange(work_items)
         else:
             assert False
 
-    def __init__(
-        self,
-        num_work_items: Sequence[PsExpression],
-    ):
-        self._dim = len(num_work_items)
-        self._num_work_items = tuple(num_work_items)
-
-    # @property
-    # def grid_size(self) -> tuple[PsExpression, ...]:
-    #     return self._grid_size
-
-    # @property
-    # def block_size(self) -> tuple[PsExpression, ...]:
-    #     return self._block_size
-
-    @property
-    def num_work_items(self) -> tuple[PsExpression, ...]:
-        """Number of work items in (x, y, z)-order."""
-        return self._num_work_items
-
-    @property
-    def dim(self) -> int:
-        return self._dim
-    
-    def __str__(self) -> str:
-        rep = "GpuThreadsRange { "
-        rep += "; ".join(f"{x}: {w}" for x, w in zip("xyz", self._num_work_items))
-        rep += " }"
-        return rep
-    
-    def _repr_html_(self) -> str:
-        return str(self)
-
-    @staticmethod
-    def _from_full_ispace(ispace: FullIterationSpace) -> GpuThreadsRange:
+    @classmethod
+    def _threads_from_full_ispace(cls, ispace: FullIterationSpace) -> GpuThreadsRange:
+        from ...codegen.kernel import GpuThreadsRange
+        
         dimensions = ispace.dimensions_in_loop_order()[::-1]
         if len(dimensions) > 3:
             raise NotImplementedError(
@@ -81,11 +59,3 @@ class GpuThreadsRange:
 
         work_items = [ispace.actual_iterations(dim) for dim in dimensions]
         return GpuThreadsRange(work_items)
-
-
-class GenericGpu(Platform):
-    @abstractmethod
-    def materialize_iteration_space(
-        self, block: PsBlock, ispace: IterationSpace
-    ) -> tuple[PsBlock, GpuThreadsRange | None]:
-        pass
diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py
index 56615af24e29ea0e65c18b612a6e7652d1b69735..9c04d6074b4feb0e63deddeb5a94cf11d920a0c0 100644
--- a/src/pystencils/backend/platforms/sycl.py
+++ b/src/pystencils/backend/platforms/sycl.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from typing import TYPE_CHECKING
 
 from ..functions import CFunction, PsMathFunction, MathFunctions
@@ -24,12 +25,12 @@ from ..extensions.cpp import CppMethodCall
 
 from ..kernelcreation import KernelCreationContext, AstFactory
 from ..constants import PsConstant
-from .generic_gpu import GenericGpu, GpuThreadsRange
+from .generic_gpu import GenericGpu
 from ..exceptions import MaterializationError
 from ...types import PsCustomType, PsIeeeFloatType, constify, PsIntegerType
 
 if TYPE_CHECKING:
-    from ...codegen.config import GpuIndexingConfig
+    from ...codegen import GpuIndexingConfig, GpuThreadsRange
 
 
 class SyclPlatform(GenericGpu):
@@ -38,6 +39,9 @@ class SyclPlatform(GenericGpu):
         self, ctx: KernelCreationContext, indexing_cfg: GpuIndexingConfig | None = None
     ):
         super().__init__(ctx)
+
+        from ...codegen.config import GpuIndexingConfig
+
         self._cfg = indexing_cfg if indexing_cfg is not None else GpuIndexingConfig()
 
     @property
@@ -113,7 +117,7 @@ class SyclPlatform(GenericGpu):
         id_decl = self._id_declaration(rank, id_symbol)
 
         dimensions = ispace.dimensions_in_loop_order()
-        launch_config = GpuThreadsRange.from_ispace(ispace)
+        launch_config = self.threads_from_ispace(ispace)
 
         indexing_decls = [id_decl]
         conds = []
@@ -188,7 +192,7 @@ class SyclPlatform(GenericGpu):
             body.statements = [sparse_idx_decl] + body.statements
             ast = body
 
-        return ast, GpuThreadsRange.from_ispace(ispace)
+        return ast, self.threads_from_ispace(ispace)
 
     def _item_type(self, rank: int):
         if not self._cfg.sycl_automatic_block_size:
diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py
index 47c00881982b994d6f1ed2b650a6de119dd5e24c..78e721f3850e0075a8079131b84ae558abb50062 100644
--- a/src/pystencils/backend/transformations/add_pragmas.py
+++ b/src/pystencils/backend/transformations/add_pragmas.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py
index f5b356432a56cc8c2a33eba6ad533947b9f2b2ad..c0406c25d820df0a1c3821074395b8709b482113 100644
--- a/src/pystencils/backend/transformations/canonicalize_symbols.py
+++ b/src/pystencils/backend/transformations/canonicalize_symbols.py
@@ -72,7 +72,6 @@ class CanonicalizeSymbols:
                     symb.dtype = constify(symb.dtype)
 
         #   Any symbols still alive now are function params or globals
-        #   Might use that to populate KernelFunction
         self._last_result = cc
 
         return node
diff --git a/src/pystencils/codegen/__init__.py b/src/pystencils/codegen/__init__.py
index be9fd9510877bb4096c89944f7067b6f342e6fa9..86f7f29400ee7991ab8faa19cdbe18bd562e16f4 100644
--- a/src/pystencils/codegen/__init__.py
+++ b/src/pystencils/codegen/__init__.py
@@ -6,8 +6,8 @@ from .config import (
     OpenMpConfig,
     GpuIndexingConfig,
 )
-
-from .kernel import Kernel
+from .parameters import Parameter
+from .kernel import Kernel, GpuKernel, GpuThreadsRange
 from .driver import create_kernel, get_driver
 
 __all__ = [
@@ -17,7 +17,10 @@ __all__ = [
     "VectorizationConfig",
     "OpenMpConfig",
     "GpuIndexingConfig",
+    "Parameter",
     "Kernel",
+    "GpuKernel",
+    "GpuThreadsRange",
     "create_kernel",
     "get_driver",
-]
\ No newline at end of file
+]
diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py
index 05e3ec3de90e33565e4a3b05e71e06c95f575f2e..b516245fa9b001288857c2f6f48443b223f4112b 100644
--- a/src/pystencils/codegen/config.py
+++ b/src/pystencils/codegen/config.py
@@ -21,15 +21,14 @@ from ..types import (
 from ..defaults import DEFAULTS
 
 if TYPE_CHECKING:
-    from ..backend.jit import JitBase
+    from ..jit import JitBase
 
 
 class PsOptionsError(Exception):
     """Indicates an option clash in the `CreateKernelConfig`."""
 
 
-class _AUTO_TYPE:
-    ...
+class _AUTO_TYPE: ...  # noqa: E701
 
 
 AUTO = _AUTO_TYPE()
@@ -336,12 +335,12 @@ class CreateKernelConfig:
         """Returns either the user-specified JIT compiler, or infers one from the target if none is given."""
         if self.jit is None:
             if self.target.is_cpu():
-                from ..backend.jit import LegacyCpuJit
+                from ..jit import LegacyCpuJit
 
                 return LegacyCpuJit()
             elif self.target == Target.CUDA:
                 try:
-                    from ..backend.jit.gpu_cupy import CupyJit
+                    from ..jit.gpu_cupy import CupyJit
 
                     if (
                         self.gpu_indexing is not None
@@ -352,12 +351,12 @@ class CreateKernelConfig:
                         return CupyJit()
 
                 except ImportError:
-                    from ..backend.jit import no_jit
+                    from ..jit import no_jit
 
                     return no_jit
 
             elif self.target == Target.SYCL:
-                from ..backend.jit import no_jit
+                from ..jit import no_jit
 
                 return no_jit
             else:
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index bc690a598000ca2227fb9ca8ed49f8c0a40f29de..0fd49b248fd4e260546bea9b1727f93a4c951463 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -1,15 +1,19 @@
 from __future__ import annotations
-
-from typing import cast, Sequence
+from typing import cast, Sequence, Iterable, TYPE_CHECKING
 from dataclasses import dataclass, replace
 
 from .target import Target
 from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO
-from .kernel import Kernel
+from .kernel import Kernel, GpuKernel, GpuThreadsRange
+from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr
+from .parameters import Parameter
 
 from ..types import create_numeric_type, PsIntegerType, PsScalarType
+
+from ..backend.memory import PsSymbol
 from ..backend.ast import PsAstNode
 from ..backend.ast.structural import PsBlock, PsLoop
+from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers
 from ..backend.kernelcreation import (
     KernelCreationContext,
     KernelAnalysis,
@@ -22,7 +26,12 @@ from ..backend.kernelcreation.iteration_space import (
     create_full_iteration_space,
     FullIterationSpace,
 )
-from ..backend.platforms import Platform, GenericCpu, GenericVectorCpu, GenericGpu
+from ..backend.platforms import (
+    Platform,
+    GenericCpu,
+    GenericVectorCpu,
+    GenericGpu,
+)
 from ..backend.exceptions import VectorizationError
 
 from ..backend.transformations import (
@@ -36,6 +45,9 @@ from ..backend.transformations import (
 from ..simp import AssignmentCollection
 from sympy.codegen.ast import AssignmentBase
 
+if TYPE_CHECKING:
+    from ..jit import JitBase
+
 
 __all__ = ["create_kernel"]
 
@@ -238,7 +250,7 @@ class DefaultKernelCreationDriver:
         kernel_ast = self._vectorize(kernel_ast)
 
         if cpu_cfg.openmp is not False:
-            from .backend.transformations import AddOpenMP
+            from ..backend.transformations import AddOpenMP
 
             params = (
                 cpu_cfg.openmp
@@ -262,7 +274,7 @@ class DefaultKernelCreationDriver:
         if vec_config is None:
             return kernel_ast
 
-        from .backend.transformations import LoopVectorizer, SelectIntrinsics
+        from ..backend.transformations import LoopVectorizer, SelectIntrinsics
 
         assert isinstance(self._platform, GenericVectorCpu)
 
@@ -359,6 +371,107 @@ class DefaultKernelCreationDriver:
             f"Code generation for target {self._target} not implemented"
         )
 
+    def _get_function_params(self, symbols: Iterable[PsSymbol]):
+        params: list[Parameter] = []
+
+        from pystencils.backend.memory import BufferBasePtr
+
+        for symb in symbols:
+            props: set[PsSymbolProperty] = set()
+            for prop in symb.properties:
+                match prop:
+                    case FieldShape() | FieldStride():
+                        props.add(prop)
+                    case BufferBasePtr(buf):
+                        field = self._ctx.find_field(buf.name)
+                        props.add(FieldBasePtr(field))
+            params.append(Parameter(symb.name, symb.get_dtype(), props))
+
+        params.sort(key=lambda p: p.name)
+        return params
+
+    def _get_headers(self, body: PsBlock):
+        req_headers = collect_required_headers(body)
+        req_headers |= self._platform.required_headers
+        req_headers |= self._ctx.required_headers
+        return req_headers
+    
+
+def create_cpu_kernel_function(
+    ctx: KernelCreationContext,
+    platform: Platform,
+    body: PsBlock,
+    function_name: str,
+    target_spec: Target,
+    jit: JitBase,
+):
+    undef_symbols = collect_undefined_symbols(body)
+
+    params = _get_function_params(ctx, undef_symbols)
+    req_headers = _get_headers(ctx, platform, body)
+
+    kfunc = Kernel(
+        body, target_spec, function_name, params, req_headers, jit
+    )
+    kfunc.metadata.update(ctx.metadata)
+    return kfunc
+
+
+def create_gpu_kernel_function(
+    ctx: KernelCreationContext,
+    platform: Platform,
+    body: PsBlock,
+    threads_range: GpuThreadsRange,
+    function_name: str,
+    target_spec: Target,
+    jit: JitBase,
+):
+    undef_symbols = collect_undefined_symbols(body)
+    for threads in threads_range.num_work_items:
+        undef_symbols |= collect_undefined_symbols(threads)
+
+    params = _get_function_params(ctx, undef_symbols)
+    req_headers = _get_headers(ctx, platform, body)
+
+    kfunc = GpuKernel(
+        body,
+        threads_range,
+        target_spec,
+        function_name,
+        params,
+        req_headers,
+        jit,
+    )
+    kfunc.metadata.update(ctx.metadata)
+    return kfunc
+
+
+def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]):
+    params: list[Parameter] = []
+
+    from pystencils.backend.memory import BufferBasePtr
+
+    for symb in symbols:
+        props: set[PsSymbolProperty] = set()
+        for prop in symb.properties:
+            match prop:
+                case FieldShape() | FieldStride():
+                    props.add(prop)
+                case BufferBasePtr(buf):
+                    field = ctx.find_field(buf.name)
+                    props.add(FieldBasePtr(field))
+        params.append(Parameter(symb.name, symb.get_dtype(), props))
+
+    params.sort(key=lambda p: p.name)
+    return params
+
+
+def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock):
+    req_headers = collect_required_headers(body)
+    req_headers |= platform.required_headers
+    req_headers |= ctx.required_headers
+    return req_headers
+
 
 @dataclass
 class StageResult:
diff --git a/src/pystencils/codegen/gpu.py b/src/pystencils/codegen/gpu.py
deleted file mode 100644
index 9cce9b55bbc838e98bc739eb2533e9626894c3a2..0000000000000000000000000000000000000000
--- a/src/pystencils/codegen/gpu.py
+++ /dev/null
@@ -1,28 +0,0 @@
-
-
-from .kernel import Kernel
-
-
-class GpuKernel(Kernel):
-    """Internal representation of a kernel function targeted at CUDA GPUs."""
-
-    def __init__(
-        self,
-        body: PsBlock,
-        threads_range: GpuThreadsRange | None,
-        target: Target,
-        name: str,
-        parameters: Sequence[KernelParameter],
-        required_headers: set[str],
-        constraints: Sequence[KernelParamsConstraint],
-        jit: JitBase,
-    ):
-        super().__init__(
-            body, target, name, parameters, required_headers, constraints, jit
-        )
-        self._threads_range = threads_range
-
-    @property
-    def threads_range(self) -> GpuThreadsRange | None:
-        """Object exposing the total size of the launch grid this kernel expects to be executed with."""
-        return self._threads_range
diff --git a/src/pystencils/codegen/kernel.py b/src/pystencils/codegen/kernel.py
index 6a0a6d57678d8d3cd7d53ffbf123d1ae032ced79..c4ad860b621d65965cd4b3cb62a1c05ad67807ee 100644
--- a/src/pystencils/codegen/kernel.py
+++ b/src/pystencils/codegen/kernel.py
@@ -1,21 +1,19 @@
 from __future__ import annotations
 
 from warnings import warn
-from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING
+from typing import Callable, Sequence, Any, TYPE_CHECKING
 from itertools import chain
 
-from .._deprecation import _deprecated
-
-from ..backend.ast.structural import PsBlock
-from ..backend.ast.analysis import collect_required_headers, collect_undefined_symbols
-from ..backend.memory import PsSymbol
-
-from ..types import PsType
-
 from .target import Target
 from .parameters import Parameter
+from ..backend.ast.structural import PsBlock
+from ..backend.ast.expressions import PsExpression
 from ..field import Field
-from ..sympyextensions import TypedSymbol
+
+from .._deprecation import _deprecated
+
+if TYPE_CHECKING:
+    from ..jit import JitBase
 
 
 class Kernel:
@@ -23,7 +21,7 @@ class Kernel:
 
     The kernel object is the final result of the translation process.
     It is immutable, and its AST should not be altered any more, either, as this
-    might invalidate information about the kernel already stored in the `KernelFunction` object.
+    might invalidate information about the kernel already stored in the kernel object.
     """
 
     def __init__(
@@ -78,7 +76,7 @@ class Kernel:
         return self._params
 
     def get_parameters(self) -> tuple[Parameter, ...]:
-        _deprecated("KernelFunction.get_parameters", "KernelFunction.parameters")
+        _deprecated("Kernel.get_parameters", "Kernel.parameters")
         return self.parameters
 
     def get_fields(self) -> set[Field]:
@@ -97,6 +95,77 @@ class Kernel:
     def required_headers(self) -> set[str]:
         return self._required_headers
 
+    def get_c_code(self) -> str:
+        from ..backend.emission import CAstPrinter
+
+        printer = CAstPrinter()
+        return printer(self)
+
+    def get_ir_code(self) -> str:
+        from ..backend.emission import IRAstPrinter
+
+        printer = IRAstPrinter()
+        return printer(self)
+
     def compile(self) -> Callable[..., None]:
         """Invoke the underlying just-in-time compiler to obtain the kernel as an executable Python function."""
         return self._jit.compile(self)
+
+
+class GpuKernel(Kernel):
+    """Internal representation of a kernel function targeted at CUDA GPUs."""
+
+    def __init__(
+        self,
+        body: PsBlock,
+        threads_range: GpuThreadsRange | None,
+        target: Target,
+        name: str,
+        parameters: Sequence[Parameter],
+        required_headers: set[str],
+        jit: JitBase,
+    ):
+        super().__init__(body, target, name, parameters, required_headers, jit)
+        self._threads_range = threads_range
+
+    @property
+    def threads_range(self) -> GpuThreadsRange | None:
+        """Object exposing the total size of the launch grid this kernel expects to be executed with."""
+        return self._threads_range
+
+
+class GpuThreadsRange:
+    """Number of threads required by a GPU kernel, in order (x, y, z)."""
+
+    def __init__(
+        self,
+        num_work_items: Sequence[PsExpression],
+    ):
+        self._dim = len(num_work_items)
+        self._num_work_items = tuple(num_work_items)
+
+    # @property
+    # def grid_size(self) -> tuple[PsExpression, ...]:
+    #     return self._grid_size
+
+    # @property
+    # def block_size(self) -> tuple[PsExpression, ...]:
+    #     return self._block_size
+
+    @property
+    def num_work_items(self) -> tuple[PsExpression, ...]:
+        """Number of work items in (x, y, z)-order."""
+        return self._num_work_items
+
+    @property
+    def dim(self) -> int:
+        return self._dim
+
+    def __str__(self) -> str:
+        rep = "GpuThreadsRange { "
+        rep += "; ".join(f"{x}: {w}" for x, w in zip("xyz", self._num_work_items))
+        rep += " }"
+        return rep
+
+    def _repr_html_(self) -> str:
+        return str(self)
diff --git a/src/pystencils/codegen/parameters.py b/src/pystencils/codegen/parameters.py
index 1e01e07aa5c3c1f9c675b8b30baa09f75290984e..d40eae220c94efea9785ec561e10b217ccb88c91 100644
--- a/src/pystencils/codegen/parameters.py
+++ b/src/pystencils/codegen/parameters.py
@@ -1,14 +1,8 @@
 from __future__ import annotations
 
 from warnings import warn
-from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING
-from itertools import chain
+from typing import Sequence, Iterable
 
-from .._deprecation import _deprecated
-
-from ..backend.ast.structural import PsBlock
-from ..backend.ast.analysis import collect_required_headers, collect_undefined_symbols
-from ..backend.memory import PsSymbol
 from .properties import (
     PsSymbolProperty,
     _FieldProperty,
@@ -16,16 +10,13 @@ from .properties import (
     FieldStride,
     FieldBasePtr,
 )
-
 from ..types import PsType
-
-from .target import Target
 from ..field import Field
 from ..sympyextensions import TypedSymbol
 
 
 class Parameter:
-    """Parameter to a `KernelFunction`."""
+    """Parameter to an output object of the code generator."""
 
     __match_args__ = ("name", "dtype", "properties")
 
@@ -45,7 +36,7 @@ class Parameter:
                         lambda p: isinstance(p, _FieldProperty), self._properties
                     )
                 ),
-                key=lambda f: f.name
+                key=lambda f: f.name,
             )
         )
 
@@ -139,4 +130,4 @@ class Parameter:
             "Use `param.fields[0].name` instead.",
             DeprecationWarning,
         )
-        return self._fields[0].name
\ No newline at end of file
+        return self._fields[0].name
diff --git a/src/pystencils/display_utils.py b/src/pystencils/display_utils.py
index 7f110c9c06f97fd37f17e734f5501f856216e56f..919dea4a8b568143065e8361fc695a044c69d541 100644
--- a/src/pystencils/display_utils.py
+++ b/src/pystencils/display_utils.py
@@ -2,9 +2,8 @@ from typing import Any, Dict, Optional
 
 import sympy as sp
 
-from pystencils.backend import KernelFunction
-from pystencils.kernel_wrapper import KernelWrapper as OldKernelWrapper
-from .backend.jit import KernelWrapper
+from .codegen import Kernel
+from .jit import KernelWrapper
 
 
 def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True):
@@ -43,32 +42,27 @@ def highlight_cpp(code: str):
     return HTML(highlight(code, CppLexer(), HtmlFormatter()))
 
 
-def get_code_obj(ast: KernelWrapper | KernelFunction, custom_backend=None):
+def get_code_obj(ast: KernelWrapper | Kernel, custom_backend=None):
     """Returns an object to display generated code (C/C++ or CUDA)
 
     Can either be displayed as HTML in Jupyter notebooks or printed as normal string.
     """
-    from pystencils.backend.emission import emit_code
-
-    if isinstance(ast, OldKernelWrapper):
-        ast = ast.ast
-    elif isinstance(ast, KernelWrapper):
-        ast = ast.kernel_function
+    if isinstance(ast, KernelWrapper):
+        func = ast.kernel_function
+    else:
+        func = ast
 
     class CodeDisplay:
-        def __init__(self, ast_input):
-            self.ast = ast_input
-
         def _repr_html_(self):
-            return highlight_cpp(emit_code(self.ast)).__html__()
+            return highlight_cpp(func.get_c_code()).__html__()
 
         def __str__(self):
-            return emit_code(self.ast)
+            return func.get_c_code()
 
         def __repr__(self):
-            return emit_code(self.ast)
+            return func.get_c_code()
 
-    return CodeDisplay(ast)
+    return CodeDisplay()
 
 
 def get_code_str(ast, custom_backend=None):
@@ -88,7 +82,7 @@ def _isnotebook():
         return False
 
 
-def show_code(ast: KernelWrapper | KernelFunction, custom_backend=None):
+def show_code(ast: KernelWrapper | Kernel, custom_backend=None):
     code = get_code_obj(ast, custom_backend)
 
     if _isnotebook():
diff --git a/src/pystencils/inspection.py b/src/pystencils/inspection.py
index 7fa3047c62d390bf55edf96e2f47fd2e6e967f23..7f050c745d34cad374644109ca527db592822db4 100644
--- a/src/pystencils/inspection.py
+++ b/src/pystencils/inspection.py
@@ -2,8 +2,8 @@ from typing import overload
 
 from .backend.ast import PsAstNode
 from .backend.emission import CAstPrinter, IRAstPrinter, EmissionError
-from .backend.kernelfunction import KernelFunction
-from .kernelcreation import StageResult, CodegenIntermediates
+from .codegen import Kernel
+from .codegen.driver import StageResult, CodegenIntermediates
 from abc import ABC, abstractmethod
 
 _UNABLE_TO_DISPLAY_CPP = """
@@ -37,7 +37,7 @@ class CodeInspectionBase(ABC):
         self._ir_printer = IRAstPrinter(annotate_constants=False)
         self._c_printer = CAstPrinter()
 
-    def _ir_tab(self, ir_obj: PsAstNode | KernelFunction):
+    def _ir_tab(self, ir_obj: PsAstNode | Kernel):
         import ipywidgets as widgets
 
         ir = self._ir_printer(ir_obj)
@@ -45,7 +45,7 @@ class CodeInspectionBase(ABC):
         self._apply_tab_layout(ir_tab)
         return ir_tab
 
-    def _cpp_tab(self, ir_obj: PsAstNode | KernelFunction):
+    def _cpp_tab(self, ir_obj: PsAstNode | Kernel):
         import ipywidgets as widgets
 
         try:
@@ -64,7 +64,7 @@ class CodeInspectionBase(ABC):
         self._apply_tab_layout(cpp_tab)
         return cpp_tab
 
-    def _graphviz_tab(self, ir_obj: PsAstNode | KernelFunction):
+    def _graphviz_tab(self, ir_obj: PsAstNode | Kernel):
         import ipywidgets as widgets
 
         graphviz_tab = widgets.HTML(_GRAPHVIZ_NOT_IMPLEMENTED)
@@ -124,7 +124,7 @@ class AstInspection(CodeInspectionBase):
 
 
 class KernelInspection(CodeInspectionBase):
-    def __init__(self, kernel: KernelFunction) -> None:
+    def __init__(self, kernel: Kernel) -> None:
         super().__init__()
         self._kernel = kernel
 
@@ -190,7 +190,7 @@ def inspect(obj: PsAstNode): ...
 
 
 @overload
-def inspect(obj: KernelFunction): ...
+def inspect(obj: Kernel): ...
 
 
 @overload
@@ -207,7 +207,7 @@ def inspect(obj):
     When run inside a Jupyter notebook, this function displays an inspection widget
     for the following types of objects:
     - `PsAstNode`
-    - `KernelFunction`
+    - `Kernel`
     - `StageResult`
     - `CodegenIntermediates`
     """
@@ -217,7 +217,7 @@ def inspect(obj):
     match obj:
         case PsAstNode():
             preview = AstInspection(obj)
-        case KernelFunction():
+        case Kernel():
             preview = KernelInspection(obj)
         case StageResult(ast, _):
             preview = AstInspection(ast)
diff --git a/src/pystencils/jit/__init__.py b/src/pystencils/jit/__init__.py
index f45cb9bff09d07fd85fd59957bf3582a3eb7f80f..a47dc4aa6e1e0be03949c9fb48854ae385a41526 100644
--- a/src/pystencils/jit/__init__.py
+++ b/src/pystencils/jit/__init__.py
@@ -2,7 +2,7 @@
 JIT compilation is realized by subclasses of `JitBase`.
 A JIT compiler may freely be created and configured by the user.
 It can then be passed to `create_kernel` using the ``jit`` argument of
-`CreateKernelConfig`, in which case it is hooked into the `KernelFunction.compile` method
+`CreateKernelConfig`, in which case it is hooked into the `Kernel.compile` method
 of the generated kernel function::
 
     my_jit = MyJit()
diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py
index 444167f9d06f281533b2b51c61b736c2924c6118..befb033e6f7969a5ffd9bc7742e9e7ab691da47d 100644
--- a/src/pystencils/jit/cpu_extension_module.py
+++ b/src/pystencils/jit/cpu_extension_module.py
@@ -9,22 +9,19 @@ from textwrap import indent
 
 import numpy as np
 
-from ..exceptions import PsInternalCompilerError
-from ..kernelfunction import (
-    KernelFunction,
-    KernelParameter,
+from ..codegen import (
+    Kernel,
+    Parameter,
 )
-from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
-from ..constraints import KernelParamsConstraint
-from ...types import (
+from ..codegen.properties import FieldBasePtr, FieldShape, FieldStride
+from ..types import (
     PsType,
     PsUnsignedIntegerType,
     PsSignedIntegerType,
     PsIeeeFloatType,
 )
-from ...types.quick import Fp, SInt, UInt
-from ...field import Field
-from ..emission import emit_code
+from ..types.quick import Fp, SInt, UInt
+from ..field import Field
 
 
 class PsKernelExtensioNModule:
@@ -38,11 +35,11 @@ class PsKernelExtensioNModule:
         self._module_name = module_name
 
         if custom_backend is not None:
-            raise PsInternalCompilerError(
+            raise Exception(
                 "The `custom_backend` parameter exists only for interface compatibility and cannot be set."
             )
 
-        self._kernels: dict[str, KernelFunction] = dict()
+        self._kernels: dict[str, Kernel] = dict()
         self._code_string: str | None = None
         self._code_hash: str | None = None
 
@@ -50,7 +47,7 @@ class PsKernelExtensioNModule:
     def module_name(self) -> str:
         return self._module_name
 
-    def add_function(self, kernel_function: KernelFunction, name: str | None = None):
+    def add_function(self, kernel_function: Kernel, name: str | None = None):
         if name is None:
             name = kernel_function.name
 
@@ -98,7 +95,7 @@ class PsKernelExtensioNModule:
             old_name = kernel.name
             kernel.name = f"kernel_{name}"
 
-            code += emit_code(kernel)
+            code += kernel.get_c_code()
             code += "\n"
             code += emit_call_wrapper(name, kernel)
             code += "\n"
@@ -122,14 +119,14 @@ class PsKernelExtensioNModule:
         print(self._code_string, file=file)
 
 
-def emit_call_wrapper(function_name: str, kernel: KernelFunction) -> str:
+def emit_call_wrapper(function_name: str, kernel: Kernel) -> str:
     builder = CallWrapperBuilder()
 
     for p in kernel.parameters:
         builder.extract_parameter(p)
 
-    for c in kernel.constraints:
-        builder.check_constraint(c)
+    # for c in kernel.constraints:
+    #     builder.check_constraint(c)
 
     builder.call(kernel, kernel.parameters)
 
@@ -206,8 +203,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         self._array_extractions: dict[Field, str] = dict()
         self._array_frees: dict[Field, str] = dict()
 
-        self._array_assoc_var_extractions: dict[KernelParameter, str] = dict()
-        self._scalar_extractions: dict[KernelParameter, str] = dict()
+        self._array_assoc_var_extractions: dict[Parameter, str] = dict()
+        self._scalar_extractions: dict[Parameter, str] = dict()
 
         self._constraint_checks: list[str] = []
 
@@ -223,7 +220,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
                 return "PyLong_AsUnsignedLong"
 
             case _:
-                raise PsInternalCompilerError(
+                raise ValueError(
                     f"Don't know how to cast Python objects to {dtype}"
                 )
 
@@ -267,7 +264,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
         return self._array_buffers[field]
 
-    def extract_scalar(self, param: KernelParameter) -> str:
+    def extract_scalar(self, param: Parameter) -> str:
         if param not in self._scalar_extractions:
             extract_func = self._scalar_extractor(param.dtype)
             code = self.TMPL_EXTRACT_SCALAR.format(
@@ -279,7 +276,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
         return param.name
 
-    def extract_array_assoc_var(self, param: KernelParameter) -> str:
+    def extract_array_assoc_var(self, param: Parameter) -> str:
         if param not in self._array_assoc_var_extractions:
             field = param.fields[0]
             buffer = self.extract_field(field)
@@ -305,31 +302,31 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
         return param.name
 
-    def extract_parameter(self, param: KernelParameter):
+    def extract_parameter(self, param: Parameter):
         if param.is_field_parameter:
             self.extract_array_assoc_var(param)
         else:
             self.extract_scalar(param)
 
-    def check_constraint(self, constraint: KernelParamsConstraint):
-        variables = constraint.get_parameters()
+#     def check_constraint(self, constraint: KernelParamsConstraint):
+#         variables = constraint.get_parameters()
 
-        for var in variables:
-            self.extract_parameter(var)
+#         for var in variables:
+#             self.extract_parameter(var)
 
-        cond = constraint.to_code()
+#         cond = constraint.to_code()
 
-        code = f"""
-if(!({cond}))
-{{
-    PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); 
-    return NULL;
-}}
-"""
+#         code = f"""
+# if(!({cond}))
+# {{
+#     PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); 
+#     return NULL;
+# }}
+# """
 
-        self._constraint_checks.append(code)
+#         self._constraint_checks.append(code)
 
-    def call(self, kernel: KernelFunction, params: tuple[KernelParameter, ...]):
+    def call(self, kernel: Kernel, params: tuple[Parameter, ...]):
         param_list = ", ".join(p.name for p in params)
         self._call = f"{kernel.name} ({param_list});"
 
diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py
index 2f5753e0528f75c2a01e34a44d06b2d10f96de3d..c208ac2196151d079ca5081f1377c55d18a9393c 100644
--- a/src/pystencils/jit/gpu_cupy.py
+++ b/src/pystencils/jit/gpu_cupy.py
@@ -8,21 +8,20 @@ try:
 except ImportError:
     HAVE_CUPY = False
 
-from ...codegen import Target
-from ...field import FieldType
+from ..codegen import Target
+from ..field import FieldType
 
-from ...types import PsType
+from ..types import PsType
 from .jit import JitBase, JitError, KernelWrapper
-from ..kernelfunction import (
-    KernelFunction,
-    GpuKernelFunction,
-    KernelParameter,
+from ..codegen import (
+    Kernel,
+    GpuKernel,
+    Parameter,
 )
-from ...codegen.properties import FieldShape, FieldStride, FieldBasePtr
-from ..emission import emit_code
-from ...types import PsStructType
+from ..codegen.properties import FieldShape, FieldStride, FieldBasePtr
+from ..types import PsStructType
 
-from ...include import get_pystencils_include_path
+from ..include import get_pystencils_include_path
 
 
 @dataclass
@@ -34,18 +33,18 @@ class LaunchGrid:
 class CupyKernelWrapper(KernelWrapper):
     def __init__(
         self,
-        kfunc: GpuKernelFunction,
+        kfunc: GpuKernel,
         raw_kernel: Any,
         block_size: tuple[int, int, int],
     ):
-        self._kfunc: GpuKernelFunction = kfunc
+        self._kfunc: GpuKernel = kfunc
         self._raw_kernel = raw_kernel
         self._block_size = block_size
         self._num_blocks: tuple[int, int, int] | None = None
         self._args_cache: dict[Any, tuple] = dict()
 
     @property
-    def kernel_function(self) -> GpuKernelFunction:
+    def kernel_function(self) -> GpuKernel:
         return self._kfunc
 
     @property
@@ -105,7 +104,7 @@ class CupyKernelWrapper(KernelWrapper):
         field_shapes = set()
         index_shapes = set()
 
-        def check_shape(field_ptr: KernelParameter, arr: cp.ndarray):
+        def check_shape(field_ptr: Parameter, arr: cp.ndarray):
             field = field_ptr.fields[0]
 
             if field.has_fixed_shape:
@@ -190,7 +189,7 @@ class CupyKernelWrapper(KernelWrapper):
                 add_arg(kparam.name, val, kparam.dtype)
 
         #   Determine launch grid
-        from ..ast.expressions import evaluate_expression
+        from ..backend.ast.expressions import evaluate_expression
 
         symbolic_threads_range = self._kfunc.threads_range
 
@@ -243,13 +242,13 @@ class CupyJit(JitBase):
             tuple(default_block_size) + (1,) * (3 - len(default_block_size)),
         )
 
-    def compile(self, kfunc: KernelFunction) -> KernelWrapper:
+    def compile(self, kfunc: Kernel) -> KernelWrapper:
         if not HAVE_CUPY:
             raise JitError(
                 "`cupy` is not installed: just-in-time-compilation of CUDA kernels is unavailable."
             )
 
-        if not isinstance(kfunc, GpuKernelFunction) or kfunc.target != Target.CUDA:
+        if not isinstance(kfunc, GpuKernel) or kfunc.target != Target.CUDA:
             raise ValueError(
                 "The CupyJit just-in-time compiler only accepts kernels generated for CUDA or HIP"
             )
@@ -269,7 +268,7 @@ class CupyJit(JitBase):
         options.append("-I" + get_pystencils_include_path())
         return tuple(options)
 
-    def _prelude(self, kfunc: GpuKernelFunction) -> str:
+    def _prelude(self, kfunc: GpuKernel) -> str:
         headers = self._runtime_headers
         headers |= kfunc.required_headers
 
@@ -286,6 +285,6 @@ class CupyJit(JitBase):
 
         return code
 
-    def _kernel_code(self, kfunc: GpuKernelFunction) -> str:
-        kernel_code = emit_code(kfunc)
+    def _kernel_code(self, kfunc: GpuKernel) -> str:
+        kernel_code = kfunc.get_c_code()
         return f'extern "C" {kernel_code}'
diff --git a/src/pystencils/jit/jit.py b/src/pystencils/jit/jit.py
index 250bba2401ff94e97e8d94f91478d3ad28ec9395..4998c14adfdc810a93d1a1f96cc310ac81c65f5d 100644
--- a/src/pystencils/jit/jit.py
+++ b/src/pystencils/jit/jit.py
@@ -3,8 +3,7 @@ from typing import Sequence, TYPE_CHECKING
 from abc import ABC, abstractmethod
 
 if TYPE_CHECKING:
-    from ..kernelfunction import KernelFunction, KernelParameter
-    from ...codegen.target import Target
+    from ..codegen import Kernel, Parameter, Target
 
 
 class JitError(Exception):
@@ -14,7 +13,7 @@ class JitError(Exception):
 class KernelWrapper(ABC):
     """Wrapper around a compiled and executable pystencils kernel."""
 
-    def __init__(self, kfunc: KernelFunction) -> None:
+    def __init__(self, kfunc: Kernel) -> None:
         self._kfunc = kfunc
 
     @abstractmethod
@@ -22,11 +21,11 @@ class KernelWrapper(ABC):
         pass
 
     @property
-    def kernel_function(self) -> KernelFunction:
+    def kernel_function(self) -> Kernel:
         return self._kfunc
     
     @property
-    def ast(self) -> KernelFunction:
+    def ast(self) -> Kernel:
         return self._kfunc
     
     @property
@@ -34,7 +33,7 @@ class KernelWrapper(ABC):
         return self._kfunc.target
     
     @property
-    def parameters(self) -> Sequence[KernelParameter]:
+    def parameters(self) -> Sequence[Parameter]:
         return self._kfunc.parameters
 
     @property
@@ -48,14 +47,14 @@ class JitBase(ABC):
     """Base class for just-in-time compilation interfaces implemented in pystencils."""
 
     @abstractmethod
-    def compile(self, kernel: KernelFunction) -> KernelWrapper:
+    def compile(self, kernel: Kernel) -> KernelWrapper:
         """Compile a kernel function and return a callable object which invokes the kernel."""
 
 
 class NoJit(JitBase):
     """Not a JIT compiler: Used to explicitly disable JIT compilation on an AST."""
 
-    def compile(self, kernel: KernelFunction) -> KernelWrapper:
+    def compile(self, kernel: Kernel) -> KernelWrapper:
         raise JitError(
             "Just-in-time compilation of this kernel was explicitly disabled."
         )
diff --git a/src/pystencils/jit/legacy_cpu.py b/src/pystencils/jit/legacy_cpu.py
index 1acd1b22ad48ac0564d255314bd6603405421fdc..514e9b60e4a5ae83a234be9f3cd514fdc7a0e555 100644
--- a/src/pystencils/jit/legacy_cpu.py
+++ b/src/pystencils/jit/legacy_cpu.py
@@ -61,7 +61,7 @@ import time
 import warnings
 
 
-from ..kernelfunction import KernelFunction
+from ..codegen import Kernel
 from .jit import JitBase, KernelWrapper
 from .cpu_extension_module import PsKernelExtensioNModule
 
@@ -71,7 +71,7 @@ from pystencils.utils import atomic_file_write, recursive_dict_update
 
 
 class CpuKernelWrapper(KernelWrapper):
-    def __init__(self, kfunc: KernelFunction, compiled_kernel: Callable[..., None]) -> None:
+    def __init__(self, kfunc: Kernel, compiled_kernel: Callable[..., None]) -> None:
         super().__init__(kfunc)
         self._compiled_kernel = compiled_kernel
 
@@ -86,7 +86,7 @@ class CpuKernelWrapper(KernelWrapper):
 class LegacyCpuJit(JitBase):
     """Wrapper around ``pystencils.cpu.cpujit``"""
 
-    def compile(self, kernel: KernelFunction) -> KernelWrapper:
+    def compile(self, kernel: Kernel) -> KernelWrapper:
         return compile_and_load(kernel)
 
 
@@ -436,7 +436,7 @@ def compile_module(code, code_hash, base_dir, compile_flags=None):
     return lib_file
 
 
-def compile_and_load(kernel: KernelFunction, custom_backend=None):
+def compile_and_load(kernel: Kernel, custom_backend=None):
     cache_config = get_cache_config()
 
     compiler_config = get_compiler_config()
diff --git a/src/pystencils/kernel_wrapper.py b/src/pystencils/kernel_wrapper.py
index afce06d77a17e0eb067d84a02bc273ba0668fc55..5095332c18fa4526fc0b7fb37aad80bc6dc18452 100644
--- a/src/pystencils/kernel_wrapper.py
+++ b/src/pystencils/kernel_wrapper.py
@@ -1,3 +1,3 @@
-from .backend.jit import KernelWrapper as _KernelWrapper
+from .jit import KernelWrapper as _KernelWrapper
 
 KernelWrapper = _KernelWrapper
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 9bf3eaf6756831e0afcaa6650e518ff36101bd65..97965f709fa092ff95f908a4dc721a6a76ec8e95 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -4,9 +4,10 @@ from .codegen import create_kernel as _create_kernel
 from warnings import warn
 
 warn(
-    "Importing anything from `pystencils.kernelcreation` is deprecated and the module will be removed in pystencils 2.1. "
+    "Importing anything from `pystencils.kernelcreation` is deprecated "
+    "and the module will be removed in pystencils 2.1. "
     "Import from `pystencils` instead.",
-    FutureWarning
+    FutureWarning,
 )
 
 
@@ -19,4 +20,3 @@ def create_staggered_kernel(
     raise NotImplementedError(
         "Staggered kernels are not yet implemented for pystencils 2.0"
     )
-
diff --git a/tests/kernelcreation/test_domain_kernels.py b/tests/kernelcreation/test_domain_kernels.py
index d02bfd8e46e8fc8c8f19bcebffc0db52787ff1bd..da261faec49940df31d59f44651956e2012b113a 100644
--- a/tests/kernelcreation/test_domain_kernels.py
+++ b/tests/kernelcreation/test_domain_kernels.py
@@ -10,16 +10,14 @@ from pystencils import (
     AssignmentCollection,
     Target,
     CreateKernelConfig,
-    CpuOptimConfig,
-    VectorizationConfig,
 )
 from pystencils.assignment import assignment_from_stencil
 
-from pystencils.kernelcreation import create_kernel, KernelFunction
+from pystencils import create_kernel, Kernel
 from pystencils.backend.emission import emit_code
 
 
-def inspect_dp_kernel(kernel: KernelFunction, gen_config: CreateKernelConfig):
+def inspect_dp_kernel(kernel: Kernel, gen_config: CreateKernelConfig):
     code = emit_code(kernel)
 
     match gen_config.target:
diff --git a/tests/kernelcreation/test_index_kernels.py b/tests/kernelcreation/test_index_kernels.py
index 5093c43ff4f74343b0fcf3f45d34b1cfb6597d05..569c0ab6a0e582de895a66c656697fdf8a5909ee 100644
--- a/tests/kernelcreation/test_index_kernels.py
+++ b/tests/kernelcreation/test_index_kernels.py
@@ -2,7 +2,7 @@ import numpy as np
 import pytest
 
 from pystencils import Assignment, Field, FieldType, AssignmentCollection, Target
-from pystencils.kernelcreation import create_kernel, CreateKernelConfig
+from pystencils import create_kernel, CreateKernelConfig
 
 
 @pytest.mark.parametrize("target", [Target.CPU, Target.GPU])
diff --git a/tests/kernelcreation/test_iteration_slices.py b/tests/kernelcreation/test_iteration_slices.py
index 94ed0295441885cb4ffc87855181659383b1fde7..47f3b5fac2b868df746469f2d1a5c5aabdc12172 100644
--- a/tests/kernelcreation/test_iteration_slices.py
+++ b/tests/kernelcreation/test_iteration_slices.py
@@ -19,7 +19,7 @@ from pystencils import (
 from pystencils.sympyextensions.integer_functions import int_rem
 from pystencils.simp import sympy_cse_on_assignment_list
 from pystencils.slicing import normalize_slice
-from pystencils.backend.jit.gpu_cupy import CupyKernelWrapper
+from pystencils.jit.gpu_cupy import CupyKernelWrapper
 
 
 def test_sliced_iteration():
diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py
index ef4806314eb52c7389bf583027bb808c42049213..109cfdc1914f86a89e150c6a7dc4e9a7bc382cf9 100644
--- a/tests/nbackend/test_code_printing.py
+++ b/tests/nbackend/test_code_printing.py
@@ -1,11 +1,6 @@
-from pystencils import Target
-
 from pystencils.backend.ast.expressions import PsExpression
-from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock
-from pystencils.backend.kernelfunction import KernelFunction
-from pystencils.backend.memory import PsSymbol, PsBuffer
+from pystencils.backend.memory import PsSymbol
 from pystencils.backend.constants import PsConstant
-from pystencils.backend.literals import PsLiteral
 from pystencils.types.quick import Fp, SInt, UInt, Bool
 from pystencils.backend.emission import CAstPrinter
 
@@ -129,7 +124,7 @@ def test_relations_precedence():
 
 def test_ternary():
     from pystencils.backend.ast.expressions import PsTernary
-    from pystencils.backend.ast.expressions import PsNot, PsAnd, PsOr
+    from pystencils.backend.ast.expressions import PsAnd, PsOr
 
     p, q = [PsExpression.make(PsSymbol(x, Bool())) for x in "pq"]
     x, y, z = [PsExpression.make(PsSymbol(x, Fp(32))) for x in "xyz"]
diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py
index 648112ef95bf5d6c3181f5c3c2527dd870220f0e..c053df9a9e0d381d5f92d129a3b9280a7d56f236 100644
--- a/tests/nbackend/test_cpujit.py
+++ b/tests/nbackend/test_cpujit.py
@@ -1,6 +1,6 @@
 import pytest
 
-from pystencils import Target
+from pystencils import Target, Kernel
 
 # from pystencils.backend.constraints import PsKernelParamsConstraint
 from pystencils.backend.memory import PsSymbol, PsBuffer
@@ -8,10 +8,9 @@ from pystencils.backend.constants import PsConstant
 
 from pystencils.backend.ast.expressions import PsBufferAcc, PsExpression
 from pystencils.backend.ast.structural import PsAssignment, PsBlock, PsLoop
-from pystencils.backend.kernelfunction import KernelFunction
 
 from pystencils.types.quick import SInt, Fp
-from pystencils.backend.jit import LegacyCpuJit
+from pystencils.jit import LegacyCpuJit
 
 import numpy as np
 
@@ -45,7 +44,7 @@ def test_pairwise_addition():
         PsBlock([update])
     )
 
-    func = KernelFunction(PsBlock([loop]), Target.CPU, "kernel", set())
+    func = Kernel(PsBlock([loop]), Target.CPU, "kernel", set())
 
     # sizes_constraint = PsKernelParamsConstraint(
     #     u.shape[0].eq(2 * v.shape[0]),
diff --git a/tests/nbackend/test_vectorization.py b/tests/nbackend/test_vectorization.py
index 55330c9ee8d5d675379418748aa085ab4ce3ae73..a4825669c0d930da1a5962d14e66a1cc0c457d8c 100644
--- a/tests/nbackend/test_vectorization.py
+++ b/tests/nbackend/test_vectorization.py
@@ -19,8 +19,8 @@ from pystencils.backend.transformations import (
     LowerToC,
 )
 from pystencils.backend.constants import PsConstant
-from pystencils.backend.kernelfunction import create_cpu_kernel_function
-from pystencils.backend.jit import LegacyCpuJit
+from pystencils.codegen.driver import create_cpu_kernel_function
+from pystencils.jit import LegacyCpuJit
 
 from pystencils import Target, fields, Assignment, Field
 from pystencils.field import create_numpy_array_with_layout