diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 2ab9063667920f8b7651934cbcdad07599858758..7bdec96cc0bd32eac08365b24186d319c27fb36a 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -80,7 +80,9 @@ def create_kernel( return driver(assignments) -def get_driver(cfg: CreateKernelConfig, *, retain_intermediates: bool = False): +def get_driver( + cfg: CreateKernelConfig, *, retain_intermediates: bool = False +) -> DefaultKernelCreationDriver: """Create a code generation driver object from the given configuration. Args: @@ -126,7 +128,7 @@ class DefaultKernelCreationDriver: def __call__( self, assignments: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase, - ): + ) -> Kernel: kernel_body = self.parse_kernel_body(assignments) match self._platform: @@ -241,7 +243,7 @@ class DefaultKernelCreationDriver: return kernel_body - def _transform_for_cpu(self, kernel_ast: PsBlock): + def _transform_for_cpu(self, kernel_ast: PsBlock) -> PsBlock: canonicalize = CanonicalizeSymbols(self._ctx, True) kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) @@ -394,7 +396,7 @@ def create_cpu_kernel_function( function_name: str, target_spec: Target, jit: JitBase, -): +) -> Kernel: undef_symbols = collect_undefined_symbols(body) params = _get_function_params(ctx, undef_symbols) @@ -413,7 +415,7 @@ def create_gpu_kernel_function( function_name: str, target_spec: Target, jit: JitBase, -): +) -> GpuKernel: undef_symbols = collect_undefined_symbols(body) if threads_range is not None: @@ -436,7 +438,9 @@ def create_gpu_kernel_function( return kfunc -def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]): +def _get_function_params( + ctx: KernelCreationContext, symbols: Iterable[PsSymbol] +) -> list[Parameter]: params: list[Parameter] = [] from pystencils.backend.memory import BufferBasePtr @@ -456,7 +460,9 @@ def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol] return params -def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock): +def _get_headers( + ctx: KernelCreationContext, platform: Platform, body: PsBlock +) -> set[str]: req_headers = collect_required_headers(body) req_headers |= platform.required_headers req_headers |= ctx.required_headers