Skip to content
Snippets Groups Projects
Commit f1a3714c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

code format and export

parent 0f0e0dfe
No related branches found
No related tags found
No related merge requests found
Pipeline #63740 failed
......@@ -10,6 +10,7 @@ from .cache import clear_cache
from .config import CreateKernelConfig
from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel
from .backend.kernelfunction import KernelFunction
from .slicing import make_slice
from .spatial_coordinates import (
x_,
......@@ -36,6 +37,7 @@ __all__ = [
"make_slice",
"CreateKernelConfig",
"create_kernel",
"KernelFunction",
"Target",
"Backend",
"show_code",
......
......@@ -2,7 +2,13 @@ from typing import cast
from .enums import Target
from .config import CreateKernelConfig
from .backend import KernelFunction, KernelParameter, FieldShapeParam, FieldStrideParam, FieldPointerParam
from .backend import (
KernelFunction,
KernelParameter,
FieldShapeParam,
FieldStrideParam,
FieldPointerParam,
)
from .backend.symbols import PsSymbol
from .backend.jit import JitBase
from .backend.ast.structural import PsBlock
......@@ -26,6 +32,7 @@ from .sympyextensions import AssignmentCollection, Assignment
__all__ = ["create_kernel"]
def create_kernel(
assignments: AssignmentCollection | list[Assignment],
config: CreateKernelConfig = CreateKernelConfig(),
......@@ -81,10 +88,18 @@ def create_kernel(
# - Loop Splitting, Tiling, Blocking
assert config.jit is not None
return create_kernel_function(ctx, kernel_ast, config.function_name, config.target, config.jit)
return create_kernel_function(
ctx, kernel_ast, config.function_name, config.target, config.jit
)
def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str, target_spec: Target, jit: JitBase):
def create_kernel_function(
ctx: KernelCreationContext,
body: PsBlock,
name: str,
target_spec: Target,
jit: JitBase,
):
undef_symbols = collect_undefined_symbols(body)
params = []
......@@ -101,18 +116,12 @@ def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str,
params.append(FieldPointerParam(name, symb.get_dtype(), field))
case PsSymbol(name, _):
params.append(KernelParameter(name, symb.get_dtype()))
params.sort(key=lambda p: p.name)
req_headers = collect_required_headers(body)
req_headers |= ctx.required_headers
return KernelFunction(
body,
target_spec,
name,
params,
req_headers,
ctx.constraints,
jit
body, target_spec, name, params, req_headers, ctx.constraints, jit
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment