From 91b9c7b41e07ed386c0bc687e7fec4f57a368332 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 13 Jan 2025 09:04:11 +0100
Subject: [PATCH] add missing type annotations

---
 src/pystencils/codegen/driver.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index 2ab906366..7bdec96cc 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -80,7 +80,9 @@ def create_kernel(
     return driver(assignments)
 
 
-def get_driver(cfg: CreateKernelConfig, *, retain_intermediates: bool = False):
+def get_driver(
+    cfg: CreateKernelConfig, *, retain_intermediates: bool = False
+) -> DefaultKernelCreationDriver:
     """Create a code generation driver object from the given configuration.
 
     Args:
@@ -126,7 +128,7 @@ class DefaultKernelCreationDriver:
     def __call__(
         self,
         assignments: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase,
-    ):
+    ) -> Kernel:
         kernel_body = self.parse_kernel_body(assignments)
 
         match self._platform:
@@ -241,7 +243,7 @@ class DefaultKernelCreationDriver:
 
         return kernel_body
 
-    def _transform_for_cpu(self, kernel_ast: PsBlock):
+    def _transform_for_cpu(self, kernel_ast: PsBlock) -> PsBlock:
         canonicalize = CanonicalizeSymbols(self._ctx, True)
         kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
 
@@ -394,7 +396,7 @@ def create_cpu_kernel_function(
     function_name: str,
     target_spec: Target,
     jit: JitBase,
-):
+) -> Kernel:
     undef_symbols = collect_undefined_symbols(body)
 
     params = _get_function_params(ctx, undef_symbols)
@@ -413,7 +415,7 @@ def create_gpu_kernel_function(
     function_name: str,
     target_spec: Target,
     jit: JitBase,
-):
+) -> GpuKernel:
     undef_symbols = collect_undefined_symbols(body)
 
     if threads_range is not None:
@@ -436,7 +438,9 @@ def create_gpu_kernel_function(
     return kfunc
 
 
-def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]):
+def _get_function_params(
+    ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
+) -> list[Parameter]:
     params: list[Parameter] = []
 
     from pystencils.backend.memory import BufferBasePtr
@@ -456,7 +460,9 @@ def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
     return params
 
 
-def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock):
+def _get_headers(
+    ctx: KernelCreationContext, platform: Platform, body: PsBlock
+) -> set[str]:
     req_headers = collect_required_headers(body)
     req_headers |= platform.required_headers
     req_headers |= ctx.required_headers
-- 
GitLab