diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 90146e596123eedc73eea5f813964f7593b6f43f..7f1b58fde5078725eb93ad5c8cfc5f15308be4e3 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -6,7 +6,7 @@ import sympy as sp from functools import reduce from pystencils import Field -from pystencils.backend import KernelFunction +from pystencils.codegen import Kernel from pystencils.types import ( create_type, UserTypeSpec, @@ -237,7 +237,7 @@ class SfgBasicComposer(SfgIComposer): return cls def kernel_function( - self, name: str, ast_or_kernel_handle: KernelFunction | SfgKernelHandle + self, name: str, ast_or_kernel_handle: Kernel | SfgKernelHandle ): """Create a function comprising just a single kernel call. @@ -247,7 +247,7 @@ class SfgBasicComposer(SfgIComposer): if self._ctx.get_function(name) is not None: raise ValueError(f"Function {name} already exists.") - if isinstance(ast_or_kernel_handle, KernelFunction): + if isinstance(ast_or_kernel_handle, Kernel): khandle = self._ctx.default_kernel_namespace.add(ast_or_kernel_handle) tree = SfgKernelCallNode(khandle) elif isinstance(ast_or_kernel_handle, SfgKernelHandle): diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index eb05ff6efa8eb8166c6d5f81e8e4596c6c797704..adf7508e705fc9d909e362433bbc13d5048d6475 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -3,7 +3,7 @@ from __future__ import annotations from textwrap import indent from itertools import chain, repeat, cycle -from pystencils import KernelFunction +from pystencils.codegen import Kernel from pystencils.backend.emission import emit_code from ..context import SfgContext @@ -233,8 +233,8 @@ class SfgImplPrinter(SfgGeneralPrinter): code += f"\n}} // namespace {kns.name}\n" return code - @visit.case(KernelFunction) - def kernel(self, kfunc: KernelFunction) -> str: + @visit.case(Kernel) + def kernel(self, kfunc: Kernel) -> str: return emit_code(kfunc) @visit.case(SfgFunction) diff --git a/src/pystencilssfg/ir/call_tree.py b/src/pystencilssfg/ir/call_tree.py index c6f4951db4397d356c80340ae99f4bf2b8ef1b8e..a5d2c5a35b1795817305515b74797c2bf3f2b91b 100644 --- a/src/pystencilssfg/ir/call_tree.py +++ b/src/pystencilssfg/ir/call_tree.py @@ -226,10 +226,10 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf): depends: set[SfgVar], ): from pystencils import Target - from pystencils.backend.kernelfunction import GpuKernelFunction + from pystencils.codegen import GpuKernel func = kernel_handle.get_kernel_function() - if not (isinstance(func, GpuKernelFunction) and func.target == Target.CUDA): + if not (isinstance(func, GpuKernel) and func.target == Target.CUDA): raise ValueError( "An `SfgCudaKernelInvocation` node can only call a CUDA kernel." ) diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index d9d59911464e885e528cba6fa1b0e9ad9f8511b8..aa3cd2732f62f5b9b50131b4e1ae1b48aa23e4ce 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -10,7 +10,7 @@ import sympy as sp from pystencils import Field from pystencils.types import deconstify, PsType -from pystencils.backend.properties import FieldBasePtr, FieldShape, FieldStride +from pystencils.codegen.properties import FieldBasePtr, FieldShape, FieldStride from ..exceptions import SfgException diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 13c4b5092e2d5926ecdd549eab45737bf05fc625..ea43ac8e06cd7520c75eb266c8ff9008ca7132a0 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -7,10 +7,7 @@ from dataclasses import replace from itertools import chain from pystencils import CreateKernelConfig, create_kernel, Field -from pystencils.backend.kernelfunction import ( - KernelFunction, - KernelParameter, -) +from pystencils.codegen import Kernel, Parameter from pystencils.types import PsType, PsCustomType from ..lang import SfgVar, HeaderFile, void @@ -68,7 +65,7 @@ class SfgKernelNamespace: def __init__(self, ctx: SfgContext, name: str): self._ctx = ctx self._name = name - self._kernel_functions: dict[str, KernelFunction] = dict() + self._kernel_functions: dict[str, Kernel] = dict() @property def name(self): @@ -78,7 +75,7 @@ class SfgKernelNamespace: def kernel_functions(self): yield from self._kernel_functions.values() - def get_kernel_function(self, khandle: SfgKernelHandle) -> KernelFunction: + def get_kernel_function(self, khandle: SfgKernelHandle) -> Kernel: if khandle.kernel_namespace is not self: raise ValueError( f"Kernel handle does not belong to this namespace: {khandle}" @@ -86,7 +83,7 @@ class SfgKernelNamespace: return self._kernel_functions[khandle.kernel_name] - def add(self, kernel: KernelFunction, name: str | None = None): + def add(self, kernel: Kernel, name: str | None = None): """Adds an existing pystencils AST to this namespace. If a name is specified, the AST's function name is changed.""" if name is not None: @@ -142,7 +139,7 @@ class SfgKernelHandle: ctx: SfgContext, name: str, namespace: SfgKernelNamespace, - parameters: Sequence[KernelParameter], + parameters: Sequence[Parameter], ): self._ctx = ctx self._name = name @@ -186,11 +183,11 @@ class SfgKernelHandle: def fields(self): return self._fields - def get_kernel_function(self) -> KernelFunction: + def get_kernel_function(self) -> Kernel: return self._namespace.get_kernel_function(self) -SymbolLike_T = TypeVar("SymbolLike_T", bound=KernelParameter) +SymbolLike_T = TypeVar("SymbolLike_T", bound=Parameter) class SfgKernelParamVar(SfgVar): @@ -198,12 +195,12 @@ class SfgKernelParamVar(SfgVar): """Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`.""" - def __init__(self, param: KernelParameter): + def __init__(self, param: Parameter): self._param = param super().__init__(param.name, param.dtype) @property - def wrapped(self) -> KernelParameter: + def wrapped(self) -> Parameter: return self._param def _args(self):