From 6386fac2cb719087275f212924a76058b2793e53 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 14 Jan 2025 13:46:55 +0100
Subject: [PATCH] Adapt to pystencils codegen API changes

---
 src/pystencilssfg/composer/basic_composer.py |  6 +++---
 src/pystencilssfg/emission/printers.py       |  6 +++---
 src/pystencilssfg/ir/call_tree.py            |  4 ++--
 src/pystencilssfg/ir/postprocessing.py       |  2 +-
 src/pystencilssfg/ir/source_components.py    | 21 +++++++++-----------
 5 files changed, 18 insertions(+), 21 deletions(-)

diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py
index 90146e5..7f1b58f 100644
--- a/src/pystencilssfg/composer/basic_composer.py
+++ b/src/pystencilssfg/composer/basic_composer.py
@@ -6,7 +6,7 @@ import sympy as sp
 from functools import reduce
 
 from pystencils import Field
-from pystencils.backend import KernelFunction
+from pystencils.codegen import Kernel
 from pystencils.types import (
     create_type,
     UserTypeSpec,
@@ -237,7 +237,7 @@ class SfgBasicComposer(SfgIComposer):
         return cls
 
     def kernel_function(
-        self, name: str, ast_or_kernel_handle: KernelFunction | SfgKernelHandle
+        self, name: str, ast_or_kernel_handle: Kernel | SfgKernelHandle
     ):
         """Create a function comprising just a single kernel call.
 
@@ -247,7 +247,7 @@ class SfgBasicComposer(SfgIComposer):
         if self._ctx.get_function(name) is not None:
             raise ValueError(f"Function {name} already exists.")
 
-        if isinstance(ast_or_kernel_handle, KernelFunction):
+        if isinstance(ast_or_kernel_handle, Kernel):
             khandle = self._ctx.default_kernel_namespace.add(ast_or_kernel_handle)
             tree = SfgKernelCallNode(khandle)
         elif isinstance(ast_or_kernel_handle, SfgKernelHandle):
diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py
index eb05ff6..adf7508 100644
--- a/src/pystencilssfg/emission/printers.py
+++ b/src/pystencilssfg/emission/printers.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 from textwrap import indent
 from itertools import chain, repeat, cycle
 
-from pystencils import KernelFunction
+from pystencils.codegen import Kernel
 from pystencils.backend.emission import emit_code
 
 from ..context import SfgContext
@@ -233,8 +233,8 @@ class SfgImplPrinter(SfgGeneralPrinter):
         code += f"\n}} // namespace {kns.name}\n"
         return code
 
-    @visit.case(KernelFunction)
-    def kernel(self, kfunc: KernelFunction) -> str:
+    @visit.case(Kernel)
+    def kernel(self, kfunc: Kernel) -> str:
         return emit_code(kfunc)
 
     @visit.case(SfgFunction)
diff --git a/src/pystencilssfg/ir/call_tree.py b/src/pystencilssfg/ir/call_tree.py
index c6f4951..a5d2c5a 100644
--- a/src/pystencilssfg/ir/call_tree.py
+++ b/src/pystencilssfg/ir/call_tree.py
@@ -226,10 +226,10 @@ class SfgCudaKernelInvocation(SfgCallTreeLeaf):
         depends: set[SfgVar],
     ):
         from pystencils import Target
-        from pystencils.backend.kernelfunction import GpuKernelFunction
+        from pystencils.codegen import GpuKernel
 
         func = kernel_handle.get_kernel_function()
-        if not (isinstance(func, GpuKernelFunction) and func.target == Target.CUDA):
+        if not (isinstance(func, GpuKernel) and func.target == Target.CUDA):
             raise ValueError(
                 "An `SfgCudaKernelInvocation` node can only call a CUDA kernel."
             )
diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py
index d9d5991..aa3cd27 100644
--- a/src/pystencilssfg/ir/postprocessing.py
+++ b/src/pystencilssfg/ir/postprocessing.py
@@ -10,7 +10,7 @@ import sympy as sp
 
 from pystencils import Field
 from pystencils.types import deconstify, PsType
-from pystencils.backend.properties import FieldBasePtr, FieldShape, FieldStride
+from pystencils.codegen.properties import FieldBasePtr, FieldShape, FieldStride
 
 from ..exceptions import SfgException
 
diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py
index 13c4b50..ea43ac8 100644
--- a/src/pystencilssfg/ir/source_components.py
+++ b/src/pystencilssfg/ir/source_components.py
@@ -7,10 +7,7 @@ from dataclasses import replace
 from itertools import chain
 
 from pystencils import CreateKernelConfig, create_kernel, Field
-from pystencils.backend.kernelfunction import (
-    KernelFunction,
-    KernelParameter,
-)
+from pystencils.codegen import Kernel, Parameter
 from pystencils.types import PsType, PsCustomType
 
 from ..lang import SfgVar, HeaderFile, void
@@ -68,7 +65,7 @@ class SfgKernelNamespace:
     def __init__(self, ctx: SfgContext, name: str):
         self._ctx = ctx
         self._name = name
-        self._kernel_functions: dict[str, KernelFunction] = dict()
+        self._kernel_functions: dict[str, Kernel] = dict()
 
     @property
     def name(self):
@@ -78,7 +75,7 @@ class SfgKernelNamespace:
     def kernel_functions(self):
         yield from self._kernel_functions.values()
 
-    def get_kernel_function(self, khandle: SfgKernelHandle) -> KernelFunction:
+    def get_kernel_function(self, khandle: SfgKernelHandle) -> Kernel:
         if khandle.kernel_namespace is not self:
             raise ValueError(
                 f"Kernel handle does not belong to this namespace: {khandle}"
@@ -86,7 +83,7 @@ class SfgKernelNamespace:
 
         return self._kernel_functions[khandle.kernel_name]
 
-    def add(self, kernel: KernelFunction, name: str | None = None):
+    def add(self, kernel: Kernel, name: str | None = None):
         """Adds an existing pystencils AST to this namespace.
         If a name is specified, the AST's function name is changed."""
         if name is not None:
@@ -142,7 +139,7 @@ class SfgKernelHandle:
         ctx: SfgContext,
         name: str,
         namespace: SfgKernelNamespace,
-        parameters: Sequence[KernelParameter],
+        parameters: Sequence[Parameter],
     ):
         self._ctx = ctx
         self._name = name
@@ -186,11 +183,11 @@ class SfgKernelHandle:
     def fields(self):
         return self._fields
 
-    def get_kernel_function(self) -> KernelFunction:
+    def get_kernel_function(self) -> Kernel:
         return self._namespace.get_kernel_function(self)
 
 
-SymbolLike_T = TypeVar("SymbolLike_T", bound=KernelParameter)
+SymbolLike_T = TypeVar("SymbolLike_T", bound=Parameter)
 
 
 class SfgKernelParamVar(SfgVar):
@@ -198,12 +195,12 @@ class SfgKernelParamVar(SfgVar):
 
     """Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`."""
 
-    def __init__(self, param: KernelParameter):
+    def __init__(self, param: Parameter):
         self._param = param
         super().__init__(param.name, param.dtype)
 
     @property
-    def wrapped(self) -> KernelParameter:
+    def wrapped(self) -> Parameter:
         return self._param
 
     def _args(self):
-- 
GitLab