diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index 2ab9063667920f8b7651934cbcdad07599858758..7bdec96cc0bd32eac08365b24186d319c27fb36a 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -80,7 +80,9 @@ def create_kernel(
     return driver(assignments)
 
 
-def get_driver(cfg: CreateKernelConfig, *, retain_intermediates: bool = False):
+def get_driver(
+    cfg: CreateKernelConfig, *, retain_intermediates: bool = False
+) -> DefaultKernelCreationDriver:
     """Create a code generation driver object from the given configuration.
 
     Args:
@@ -126,7 +128,7 @@ class DefaultKernelCreationDriver:
     def __call__(
         self,
         assignments: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase,
-    ):
+    ) -> Kernel:
         kernel_body = self.parse_kernel_body(assignments)
 
         match self._platform:
@@ -241,7 +243,7 @@ class DefaultKernelCreationDriver:
 
         return kernel_body
 
-    def _transform_for_cpu(self, kernel_ast: PsBlock):
+    def _transform_for_cpu(self, kernel_ast: PsBlock) -> PsBlock:
         canonicalize = CanonicalizeSymbols(self._ctx, True)
         kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
 
@@ -394,7 +396,7 @@ def create_cpu_kernel_function(
     function_name: str,
     target_spec: Target,
     jit: JitBase,
-):
+) -> Kernel:
     undef_symbols = collect_undefined_symbols(body)
 
     params = _get_function_params(ctx, undef_symbols)
@@ -413,7 +415,7 @@ def create_gpu_kernel_function(
     function_name: str,
     target_spec: Target,
     jit: JitBase,
-):
+) -> GpuKernel:
     undef_symbols = collect_undefined_symbols(body)
 
     if threads_range is not None:
@@ -436,7 +438,9 @@ def create_gpu_kernel_function(
     return kfunc
 
 
-def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]):
+def _get_function_params(
+    ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
+) -> list[Parameter]:
     params: list[Parameter] = []
 
     from pystencils.backend.memory import BufferBasePtr
@@ -456,7 +460,9 @@ def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
     return params
 
 
-def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock):
+def _get_headers(
+    ctx: KernelCreationContext, platform: Platform, body: PsBlock
+) -> set[str]:
     req_headers = collect_required_headers(body)
     req_headers |= platform.required_headers
     req_headers |= ctx.required_headers