From 91b9c7b41e07ed386c0bc687e7fec4f57a368332 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 13 Jan 2025 09:04:11 +0100 Subject: [PATCH] add missing type annotations --- src/pystencils/codegen/driver.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 2ab906366..7bdec96cc 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -80,7 +80,9 @@ def create_kernel( return driver(assignments) -def get_driver(cfg: CreateKernelConfig, *, retain_intermediates: bool = False): +def get_driver( + cfg: CreateKernelConfig, *, retain_intermediates: bool = False +) -> DefaultKernelCreationDriver: """Create a code generation driver object from the given configuration. Args: @@ -126,7 +128,7 @@ class DefaultKernelCreationDriver: def __call__( self, assignments: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase, - ): + ) -> Kernel: kernel_body = self.parse_kernel_body(assignments) match self._platform: @@ -241,7 +243,7 @@ class DefaultKernelCreationDriver: return kernel_body - def _transform_for_cpu(self, kernel_ast: PsBlock): + def _transform_for_cpu(self, kernel_ast: PsBlock) -> PsBlock: canonicalize = CanonicalizeSymbols(self._ctx, True) kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) @@ -394,7 +396,7 @@ def create_cpu_kernel_function( function_name: str, target_spec: Target, jit: JitBase, -): +) -> Kernel: undef_symbols = collect_undefined_symbols(body) params = _get_function_params(ctx, undef_symbols) @@ -413,7 +415,7 @@ def create_gpu_kernel_function( function_name: str, target_spec: Target, jit: JitBase, -): +) -> GpuKernel: undef_symbols = collect_undefined_symbols(body) if threads_range is not None: @@ -436,7 +438,9 @@ def create_gpu_kernel_function( return kfunc -def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]): +def _get_function_params( + ctx: KernelCreationContext, symbols: Iterable[PsSymbol] +) -> list[Parameter]: params: list[Parameter] = [] from pystencils.backend.memory import BufferBasePtr @@ -456,7 +460,9 @@ def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol] return params -def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock): +def _get_headers( + ctx: KernelCreationContext, platform: Platform, body: PsBlock +) -> set[str]: req_headers = collect_required_headers(body) req_headers |= platform.required_headers req_headers |= ctx.required_headers -- GitLab