Skip to content
Snippets Groups Projects
Commit 4d7c38f5 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

adapt to API changes with introduction of codegen module in pystencils

parent 448e9087
1 merge request!11Adapt to pystencils codegen API changes
Pipeline #71844 passed with stages
in 2 minutes and 44 seconds
...@@ -6,7 +6,7 @@ import sympy as sp ...@@ -6,7 +6,7 @@ import sympy as sp
from functools import reduce from functools import reduce
from pystencils import Field from pystencils import Field
from pystencils.backend import KernelFunction from pystencils.codegen import Kernel
from pystencils.types import ( from pystencils.types import (
create_type, create_type,
UserTypeSpec, UserTypeSpec,
...@@ -237,7 +237,7 @@ class SfgBasicComposer(SfgIComposer): ...@@ -237,7 +237,7 @@ class SfgBasicComposer(SfgIComposer):
return cls return cls
def kernel_function( def kernel_function(
self, name: str, ast_or_kernel_handle: KernelFunction | SfgKernelHandle self, name: str, ast_or_kernel_handle: Kernel | SfgKernelHandle
): ):
"""Create a function comprising just a single kernel call. """Create a function comprising just a single kernel call.
...@@ -247,7 +247,7 @@ class SfgBasicComposer(SfgIComposer): ...@@ -247,7 +247,7 @@ class SfgBasicComposer(SfgIComposer):
if self._ctx.get_function(name) is not None: if self._ctx.get_function(name) is not None:
raise ValueError(f"Function {name} already exists.") raise ValueError(f"Function {name} already exists.")
if isinstance(ast_or_kernel_handle, KernelFunction): if isinstance(ast_or_kernel_handle, Kernel):
khandle = self._ctx.default_kernel_namespace.add(ast_or_kernel_handle) khandle = self._ctx.default_kernel_namespace.add(ast_or_kernel_handle)
tree = SfgKernelCallNode(khandle) tree = SfgKernelCallNode(khandle)
elif isinstance(ast_or_kernel_handle, SfgKernelHandle): elif isinstance(ast_or_kernel_handle, SfgKernelHandle):
......
...@@ -3,7 +3,7 @@ from __future__ import annotations ...@@ -3,7 +3,7 @@ from __future__ import annotations
from textwrap import indent from textwrap import indent
from itertools import chain, repeat, cycle from itertools import chain, repeat, cycle
from pystencils import KernelFunction from pystencils.codegen import Kernel
from pystencils.backend.emission import emit_code from pystencils.backend.emission import emit_code
from ..context import SfgContext from ..context import SfgContext
...@@ -233,8 +233,8 @@ class SfgImplPrinter(SfgGeneralPrinter): ...@@ -233,8 +233,8 @@ class SfgImplPrinter(SfgGeneralPrinter):
code += f"\n}} // namespace {kns.name}\n" code += f"\n}} // namespace {kns.name}\n"
return code return code
@visit.case(KernelFunction) @visit.case(Kernel)
def kernel(self, kfunc: KernelFunction) -> str: def kernel(self, kfunc: Kernel) -> str:
return emit_code(kfunc) return emit_code(kfunc)
@visit.case(SfgFunction) @visit.case(SfgFunction)
......
...@@ -226,10 +226,10 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf): ...@@ -226,10 +226,10 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf):
depends: set[SfgVar], depends: set[SfgVar],
): ):
from pystencils import Target from pystencils import Target
from pystencils.backend.kernelfunction import GpuKernelFunction from pystencils.codegen import GpuKernel
func = kernel_handle.get_kernel_function() func = kernel_handle.get_kernel_function()
if not (isinstance(func, GpuKernelFunction) and func.target == Target.CUDA): if not (isinstance(func, GpuKernel) and func.target == Target.CUDA):
raise ValueError( raise ValueError(
"An `SfgCudaKernelInvocation` node can only call a CUDA kernel." "An `SfgCudaKernelInvocation` node can only call a CUDA kernel."
) )
......
...@@ -10,7 +10,7 @@ import sympy as sp ...@@ -10,7 +10,7 @@ import sympy as sp
from pystencils import Field from pystencils import Field
from pystencils.types import deconstify, PsType from pystencils.types import deconstify, PsType
from pystencils.backend.properties import FieldBasePtr, FieldShape, FieldStride from pystencils.codegen.properties import FieldBasePtr, FieldShape, FieldStride
from ..exceptions import SfgException from ..exceptions import SfgException
......
...@@ -7,10 +7,7 @@ from dataclasses import replace ...@@ -7,10 +7,7 @@ from dataclasses import replace
from itertools import chain from itertools import chain
from pystencils import CreateKernelConfig, create_kernel, Field from pystencils import CreateKernelConfig, create_kernel, Field
from pystencils.backend.kernelfunction import ( from pystencils.codegen import Kernel, Parameter
KernelFunction,
KernelParameter,
)
from pystencils.types import PsType, PsCustomType from pystencils.types import PsType, PsCustomType
from ..lang import SfgVar, HeaderFile, void from ..lang import SfgVar, HeaderFile, void
...@@ -68,7 +65,7 @@ class SfgKernelNamespace: ...@@ -68,7 +65,7 @@ class SfgKernelNamespace:
def __init__(self, ctx: SfgContext, name: str): def __init__(self, ctx: SfgContext, name: str):
self._ctx = ctx self._ctx = ctx
self._name = name self._name = name
self._kernel_functions: dict[str, KernelFunction] = dict() self._kernel_functions: dict[str, Kernel] = dict()
@property @property
def name(self): def name(self):
...@@ -78,7 +75,7 @@ class SfgKernelNamespace: ...@@ -78,7 +75,7 @@ class SfgKernelNamespace:
def kernel_functions(self): def kernel_functions(self):
yield from self._kernel_functions.values() yield from self._kernel_functions.values()
def get_kernel_function(self, khandle: SfgKernelHandle) -> KernelFunction: def get_kernel_function(self, khandle: SfgKernelHandle) -> Kernel:
if khandle.kernel_namespace is not self: if khandle.kernel_namespace is not self:
raise ValueError( raise ValueError(
f"Kernel handle does not belong to this namespace: {khandle}" f"Kernel handle does not belong to this namespace: {khandle}"
...@@ -86,7 +83,7 @@ class SfgKernelNamespace: ...@@ -86,7 +83,7 @@ class SfgKernelNamespace:
return self._kernel_functions[khandle.kernel_name] return self._kernel_functions[khandle.kernel_name]
def add(self, kernel: KernelFunction, name: str | None = None): def add(self, kernel: Kernel, name: str | None = None):
"""Adds an existing pystencils AST to this namespace. """Adds an existing pystencils AST to this namespace.
If a name is specified, the AST's function name is changed.""" If a name is specified, the AST's function name is changed."""
if name is not None: if name is not None:
...@@ -142,7 +139,7 @@ class SfgKernelHandle: ...@@ -142,7 +139,7 @@ class SfgKernelHandle:
ctx: SfgContext, ctx: SfgContext,
name: str, name: str,
namespace: SfgKernelNamespace, namespace: SfgKernelNamespace,
parameters: Sequence[KernelParameter], parameters: Sequence[Parameter],
): ):
self._ctx = ctx self._ctx = ctx
self._name = name self._name = name
...@@ -186,11 +183,11 @@ class SfgKernelHandle: ...@@ -186,11 +183,11 @@ class SfgKernelHandle:
def fields(self): def fields(self):
return self._fields return self._fields
def get_kernel_function(self) -> KernelFunction: def get_kernel_function(self) -> Kernel:
return self._namespace.get_kernel_function(self) return self._namespace.get_kernel_function(self)
SymbolLike_T = TypeVar("SymbolLike_T", bound=KernelParameter) SymbolLike_T = TypeVar("SymbolLike_T", bound=Parameter)
class SfgKernelParamVar(SfgVar): class SfgKernelParamVar(SfgVar):
...@@ -198,12 +195,12 @@ class SfgKernelParamVar(SfgVar): ...@@ -198,12 +195,12 @@ class SfgKernelParamVar(SfgVar):
"""Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`.""" """Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`."""
def __init__(self, param: KernelParameter): def __init__(self, param: Parameter):
self._param = param self._param = param
super().__init__(param.name, param.dtype) super().__init__(param.name, param.dtype)
@property @property
def wrapped(self) -> KernelParameter: def wrapped(self) -> Parameter:
return self._param return self._param
def _args(self): def _args(self):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment