Skip to content
Snippets Groups Projects
Commit 91b9c7b4 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

add missing type annotations

parent c6e2ec20
No related branches found
No related tags found
1 merge request!433Consolidate codegen and JIT modules.
Pipeline #71796 failed
...@@ -80,7 +80,9 @@ def create_kernel( ...@@ -80,7 +80,9 @@ def create_kernel(
return driver(assignments) return driver(assignments)
def get_driver(cfg: CreateKernelConfig, *, retain_intermediates: bool = False): def get_driver(
cfg: CreateKernelConfig, *, retain_intermediates: bool = False
) -> DefaultKernelCreationDriver:
"""Create a code generation driver object from the given configuration. """Create a code generation driver object from the given configuration.
Args: Args:
...@@ -126,7 +128,7 @@ class DefaultKernelCreationDriver: ...@@ -126,7 +128,7 @@ class DefaultKernelCreationDriver:
def __call__( def __call__(
self, self,
assignments: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase, assignments: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase,
): ) -> Kernel:
kernel_body = self.parse_kernel_body(assignments) kernel_body = self.parse_kernel_body(assignments)
match self._platform: match self._platform:
...@@ -241,7 +243,7 @@ class DefaultKernelCreationDriver: ...@@ -241,7 +243,7 @@ class DefaultKernelCreationDriver:
return kernel_body return kernel_body
def _transform_for_cpu(self, kernel_ast: PsBlock): def _transform_for_cpu(self, kernel_ast: PsBlock) -> PsBlock:
canonicalize = CanonicalizeSymbols(self._ctx, True) canonicalize = CanonicalizeSymbols(self._ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
...@@ -394,7 +396,7 @@ def create_cpu_kernel_function( ...@@ -394,7 +396,7 @@ def create_cpu_kernel_function(
function_name: str, function_name: str,
target_spec: Target, target_spec: Target,
jit: JitBase, jit: JitBase,
): ) -> Kernel:
undef_symbols = collect_undefined_symbols(body) undef_symbols = collect_undefined_symbols(body)
params = _get_function_params(ctx, undef_symbols) params = _get_function_params(ctx, undef_symbols)
...@@ -413,7 +415,7 @@ def create_gpu_kernel_function( ...@@ -413,7 +415,7 @@ def create_gpu_kernel_function(
function_name: str, function_name: str,
target_spec: Target, target_spec: Target,
jit: JitBase, jit: JitBase,
): ) -> GpuKernel:
undef_symbols = collect_undefined_symbols(body) undef_symbols = collect_undefined_symbols(body)
if threads_range is not None: if threads_range is not None:
...@@ -436,7 +438,9 @@ def create_gpu_kernel_function( ...@@ -436,7 +438,9 @@ def create_gpu_kernel_function(
return kfunc return kfunc
def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]): def _get_function_params(
ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
) -> list[Parameter]:
params: list[Parameter] = [] params: list[Parameter] = []
from pystencils.backend.memory import BufferBasePtr from pystencils.backend.memory import BufferBasePtr
...@@ -456,7 +460,9 @@ def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol] ...@@ -456,7 +460,9 @@ def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
return params return params
def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock): def _get_headers(
ctx: KernelCreationContext, platform: Platform, body: PsBlock
) -> set[str]:
req_headers = collect_required_headers(body) req_headers = collect_required_headers(body)
req_headers |= platform.required_headers req_headers |= platform.required_headers
req_headers |= ctx.required_headers req_headers |= ctx.required_headers
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment