Skip to content
Snippets Groups Projects
Commit f1a3714c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

code format and export

parent 0f0e0dfe
No related merge requests found
Pipeline #63740 failed with stages
in 3 minutes and 46 seconds
......@@ -10,6 +10,7 @@ from .cache import clear_cache
from .config import CreateKernelConfig
from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel
from .backend.kernelfunction import KernelFunction
from .slicing import make_slice
from .spatial_coordinates import (
x_,
......@@ -36,6 +37,7 @@ __all__ = [
"make_slice",
"CreateKernelConfig",
"create_kernel",
"KernelFunction",
"Target",
"Backend",
"show_code",
......
......@@ -2,7 +2,13 @@ from typing import cast
from .enums import Target
from .config import CreateKernelConfig
from .backend import KernelFunction, KernelParameter, FieldShapeParam, FieldStrideParam, FieldPointerParam
from .backend import (
KernelFunction,
KernelParameter,
FieldShapeParam,
FieldStrideParam,
FieldPointerParam,
)
from .backend.symbols import PsSymbol
from .backend.jit import JitBase
from .backend.ast.structural import PsBlock
......@@ -26,6 +32,7 @@ from .sympyextensions import AssignmentCollection, Assignment
__all__ = ["create_kernel"]
def create_kernel(
assignments: AssignmentCollection | list[Assignment],
config: CreateKernelConfig = CreateKernelConfig(),
......@@ -81,10 +88,18 @@ def create_kernel(
# - Loop Splitting, Tiling, Blocking
assert config.jit is not None
return create_kernel_function(ctx, kernel_ast, config.function_name, config.target, config.jit)
return create_kernel_function(
ctx, kernel_ast, config.function_name, config.target, config.jit
)
def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str, target_spec: Target, jit: JitBase):
def create_kernel_function(
ctx: KernelCreationContext,
body: PsBlock,
name: str,
target_spec: Target,
jit: JitBase,
):
undef_symbols = collect_undefined_symbols(body)
params = []
......@@ -101,18 +116,12 @@ def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str,
params.append(FieldPointerParam(name, symb.get_dtype(), field))
case PsSymbol(name, _):
params.append(KernelParameter(name, symb.get_dtype()))
params.sort(key=lambda p: p.name)
req_headers = collect_required_headers(body)
req_headers |= ctx.required_headers
return KernelFunction(
body,
target_spec,
name,
params,
req_headers,
ctx.constraints,
jit
body, target_spec, name, params, req_headers, ctx.constraints, jit
)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment