Skip to content
Snippets Groups Projects
Commit 91b9c7b4 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

add missing type annotations

parent c6e2ec20
No related branches found
No related tags found
1 merge request!433Consolidate codegen and JIT modules.
Pipeline #71796 failed
......@@ -80,7 +80,9 @@ def create_kernel(
return driver(assignments)
def get_driver(cfg: CreateKernelConfig, *, retain_intermediates: bool = False):
def get_driver(
cfg: CreateKernelConfig, *, retain_intermediates: bool = False
) -> DefaultKernelCreationDriver:
"""Create a code generation driver object from the given configuration.
Args:
......@@ -126,7 +128,7 @@ class DefaultKernelCreationDriver:
def __call__(
self,
assignments: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase,
):
) -> Kernel:
kernel_body = self.parse_kernel_body(assignments)
match self._platform:
......@@ -241,7 +243,7 @@ class DefaultKernelCreationDriver:
return kernel_body
def _transform_for_cpu(self, kernel_ast: PsBlock):
def _transform_for_cpu(self, kernel_ast: PsBlock) -> PsBlock:
canonicalize = CanonicalizeSymbols(self._ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
......@@ -394,7 +396,7 @@ def create_cpu_kernel_function(
function_name: str,
target_spec: Target,
jit: JitBase,
):
) -> Kernel:
undef_symbols = collect_undefined_symbols(body)
params = _get_function_params(ctx, undef_symbols)
......@@ -413,7 +415,7 @@ def create_gpu_kernel_function(
function_name: str,
target_spec: Target,
jit: JitBase,
):
) -> GpuKernel:
undef_symbols = collect_undefined_symbols(body)
if threads_range is not None:
......@@ -436,7 +438,9 @@ def create_gpu_kernel_function(
return kfunc
def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]):
def _get_function_params(
ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
) -> list[Parameter]:
params: list[Parameter] = []
from pystencils.backend.memory import BufferBasePtr
......@@ -456,7 +460,9 @@ def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
return params
def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock):
def _get_headers(
ctx: KernelCreationContext, platform: Platform, body: PsBlock
) -> set[str]:
req_headers = collect_required_headers(body)
req_headers |= platform.required_headers
req_headers |= ctx.required_headers
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment