From 639c590d7f583f317d13fd0ec532b0c217646226 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 23 Jul 2024 14:38:08 +0200
Subject: [PATCH] API compatibility patches for pystencils 2.0.dev0

---
 .gitignore                                    |  1 +
 apps/tutorials/codegen/HeatEquationKernel.py  | 12 +--
 .../lbmpy_walberla/additional_data_handler.py | 27 ++++--
 python/lbmpy_walberla/boundary_collection.py  |  2 +-
 python/lbmpy_walberla/function_generator.py   |  2 +-
 python/lbmpy_walberla/packing_kernels.py      | 18 ++--
 .../lbmpy_walberla/storage_specification.py   |  5 +-
 python/lbmpy_walberla/sweep_collection.py     | 32 +++++--
 python/lbmpy_walberla/utility.py              |  5 +-
 .../lbmpy_walberla/walberla_lbm_generation.py | 20 ++--
 python/pystencils_walberla/boundary.py        | 28 ++++--
 python/pystencils_walberla/compat.py          | 93 +++++++++++++++++++
 python/pystencils_walberla/jinja_filters.py   | 26 +++---
 python/pystencils_walberla/kernel_info.py     | 25 ++---
 .../pystencils_walberla/kernel_selection.py   | 23 +++--
 python/pystencils_walberla/pack_info.py       |  4 +-
 python/pystencils_walberla/sweep.py           | 15 +--
 .../templates/Boundary.tmpl.cpp               |  4 +-
 python/pystencils_walberla/utility.py         | 16 +++-
 tests/lbm_generated/InterpolationNoSlip.py    |  2 +-
 20 files changed, 259 insertions(+), 101 deletions(-)
 create mode 100644 python/pystencils_walberla/compat.py

diff --git a/.gitignore b/.gitignore
index 9b148120d..4495383c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -64,6 +64,7 @@ logfile*.txt
 
 
 # CMake
+CMakeUserPresets.json
 /CMakeLists.txt.user
 
 # CMake build files
diff --git a/apps/tutorials/codegen/HeatEquationKernel.py b/apps/tutorials/codegen/HeatEquationKernel.py
index 024940a38..ee0c9c252 100644
--- a/apps/tutorials/codegen/HeatEquationKernel.py
+++ b/apps/tutorials/codegen/HeatEquationKernel.py
@@ -5,7 +5,7 @@ from pystencils_walberla import CodeGeneration, generate_sweep
 with CodeGeneration() as ctx:
     data_type = "float64" if ctx.double_accuracy else "float32"
 
-    u, u_tmp = ps.fields(f"u, u_tmp: {data_type}[2D]", layout='fzyx')
+    u, u_tmp = ps.fields(f"u, u_tmp: {data_type}[2D]", layout="fzyx")
     kappa = sp.Symbol("kappa")
     dx = sp.Symbol("dx")
     dt = sp.Symbol("dt")
@@ -13,15 +13,15 @@ with CodeGeneration() as ctx:
 
     discretize = ps.fd.Discretization2ndOrder(dx=dx, dt=dt)
     heat_pde_discretized = discretize(heat_pde)
-    heat_pde_discretized = heat_pde_discretized.args[1] + heat_pde_discretized.args[0].simplify()
-
+    heat_pde_discretized = (
+        heat_pde_discretized.args[1] + heat_pde_discretized.args[0].simplify()
+    )
 
     @ps.kernel
     def update():
         u_tmp.center @= heat_pde_discretized
 
-
     ac = ps.AssignmentCollection(update)
-    ac = ps.simp.simplifications.add_subexpressions_for_divisions(ac)
+    ac = ps.simp.add_subexpressions_for_divisions(ac)
 
-    generate_sweep(ctx, 'HeatEquationKernel', ac)
+    generate_sweep(ctx, "HeatEquationKernel", ac)
diff --git a/python/lbmpy_walberla/additional_data_handler.py b/python/lbmpy_walberla/additional_data_handler.py
index daaab32fd..0f1cd58c6 100644
--- a/python/lbmpy_walberla/additional_data_handler.py
+++ b/python/lbmpy_walberla/additional_data_handler.py
@@ -1,13 +1,20 @@
 from pystencils import Target
 from pystencils.stencil import inverse_direction
-from pystencils.typing import BasicType
+from pystencils.typing import create_type
+
+from pystencils_walberla.compat import IS_PYSTENCILS_2
 
 from lbmpy.advanced_streaming import AccessPdfValues, numeric_offsets, numeric_index, Timestep, is_inplace
-# until lbmpy version 1.3.2 
-try:
-    from lbmpy.advanced_streaming.indexing import MirroredStencilDirections
-except ImportError:
-    from lbmpy.custom_code_nodes import MirroredStencilDirections
+
+if IS_PYSTENCILS_2:
+    from lbmpy.lookup_tables import MirroredStencilDirections
+else:
+    # until lbmpy version 1.3.2
+    try:
+        from lbmpy.advanced_streaming.indexing import MirroredStencilDirections
+    except ImportError:
+        from lbmpy.custom_code_nodes import MirroredStencilDirections
+
 from lbmpy.boundaries.boundaryconditions import LbBoundary
 from lbmpy.boundaries import (ExtrapolationOutflow, FreeSlip, UBB, DiffusionDirichlet,
                               NoSlipLinearBouzidi, QuadraticBounceBack)
@@ -153,7 +160,7 @@ class NoSlipLinearBouzidiAdditionalDataHandler(AdditionalDataHandler):
     def __init__(self, stencil, boundary_object):
         assert isinstance(boundary_object, NoSlipLinearBouzidi)
 
-        self._dtype = BasicType(boundary_object.data_type).c_name
+        self._dtype = create_type(boundary_object.data_type).c_name
         self._blocks = "const shared_ptr<StructuredBlockForest>&, IBlock&)>"
         super(NoSlipLinearBouzidiAdditionalDataHandler, self).__init__(stencil=stencil)
 
@@ -201,7 +208,7 @@ class QuadraticBounceBackAdditionalDataHandler(AdditionalDataHandler):
     def __init__(self, stencil, boundary_object):
         assert isinstance(boundary_object, QuadraticBounceBack)
 
-        self._dtype = BasicType(boundary_object.data_type).c_name
+        self._dtype = create_type(boundary_object.data_type).c_name
         self._blocks = "const shared_ptr<StructuredBlockForest>&, IBlock&)>"
         super(QuadraticBounceBackAdditionalDataHandler, self).__init__(stencil=stencil)
 
@@ -251,11 +258,11 @@ class OutflowAdditionalDataHandler(AdditionalDataHandler):
         self._normal_direction = boundary_object.normal_direction
         self._field_name = field_name
         self._target = target
-        self._dtype = BasicType(boundary_object.data_type).c_name
+        self._dtype = create_type(boundary_object.data_type).c_name
         if pdfs_data_type is None:
             self._pdfs_data_type = "real_t"
         else:
-            pdfs_data_type = BasicType(pdfs_data_type)
+            pdfs_data_type = create_type(pdfs_data_type)
             self._pdfs_data_type = pdfs_data_type.c_name
 
         self._streaming_pattern = boundary_object.streaming_pattern
diff --git a/python/lbmpy_walberla/boundary_collection.py b/python/lbmpy_walberla/boundary_collection.py
index 3830d8bb4..ccb98ace5 100644
--- a/python/lbmpy_walberla/boundary_collection.py
+++ b/python/lbmpy_walberla/boundary_collection.py
@@ -7,7 +7,7 @@ from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_en
 from lbmpy.advanced_streaming import Timestep, is_inplace
 
 from pystencils_walberla.kernel_selection import KernelCallNode
-from lbmpy_walberla.alternating_sweeps import EvenIntegerCondition, OddIntegerCondition, TimestepTrackerMapping
+from lbmpy_walberla.alternating_sweeps import EvenIntegerCondition, OddIntegerCondition
 from lbmpy_walberla.additional_data_handler import default_additional_data_handler
 
 from pystencils import Target
diff --git a/python/lbmpy_walberla/function_generator.py b/python/lbmpy_walberla/function_generator.py
index 8e3d552c2..4a54fb19a 100644
--- a/python/lbmpy_walberla/function_generator.py
+++ b/python/lbmpy_walberla/function_generator.py
@@ -1,4 +1,4 @@
-from pystencils_walberla.kernel_selection import KernelCallNode, KernelFamily, HighLevelInterfaceSpec
+from pystencils_walberla.kernel_selection import KernelFamily, HighLevelInterfaceSpec
 
 
 def kernel_family_function_generator(class_name: str, kernel_family: KernelFamily,
diff --git a/python/lbmpy_walberla/packing_kernels.py b/python/lbmpy_walberla/packing_kernels.py
index 8a8728031..9f3e5114e 100644
--- a/python/lbmpy_walberla/packing_kernels.py
+++ b/python/lbmpy_walberla/packing_kernels.py
@@ -8,7 +8,8 @@ from jinja2 import Environment, PackageLoader, StrictUndefined
 
 from pystencils import Assignment, CreateKernelConfig, create_kernel, Field, FieldType, fields, Target
 from pystencils.stencil import offset_to_direction_string
-from pystencils.typing import TypedSymbol
+from pystencils import TypedSymbol
+from pystencils.typing import create_type
 from pystencils.stencil import inverse_direction
 from pystencils.bit_masks import flag_cond
 
@@ -17,6 +18,7 @@ from lbmpy.advanced_streaming.communication import _extend_dir
 from lbmpy.enums import Stencil
 from lbmpy.stencils import LBStencil
 
+from pystencils_walberla.compat import IS_PYSTENCILS_2, custom_type, target_string
 from pystencils_walberla.cmake_integration import CodeGenerationContext
 from pystencils_walberla.kernel_selection import KernelFamily, KernelCallNode, SwitchNode
 from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env
@@ -52,7 +54,7 @@ def generate_packing_kernels(generation_context: CodeGenerationContext, class_na
         'class_name': class_name,
         'namespace': namespace,
         'nonuniform': nonuniform,
-        'target': target.name.lower(),
+        'target': target_string(target),
         'dtype': "float" if is_float else "double",
         'is_gpu': target == Target.GPU,
         'kernels': kernels,
@@ -93,7 +95,11 @@ class PackingKernelsCodegen:
         self.inplace = is_inplace(streaming_pattern)
         self.class_name = class_name
         self.config = config
-        self.data_type = config.data_type['pdfs'].numpy_dtype
+
+        if IS_PYSTENCILS_2:
+            self.data_type = create_type(config.default_dtype).numpy_dtype
+        else:
+            self.data_type = config.data_type['pdfs'].numpy_dtype
 
         self.src_field = src_field if src_field else fields(f'pdfs_src({stencil.Q}) :{self.data_type}[{stencil.D}D]')
         self.dst_field = dst_field if dst_field else fields(f'pdfs_dst({stencil.Q}) :{self.data_type}[{stencil.D}D]')
@@ -277,7 +283,7 @@ class PackingKernelsCodegen:
         function_name = f'unpackRedistribute_{dir_string}' + timestep_suffix(timestep)
         iteration_slice = tuple(slice(None, None, 2) for _ in range(self.dim))
         config = CreateKernelConfig(function_name=function_name, iteration_slice=iteration_slice,
-                                    data_type=self.data_type, ghost_layers=0, allow_double_writes=True,
+                                    data_type=self.data_type, allow_double_writes=True,
                                     cpu_openmp=self.config.cpu_openmp, target=self.config.target)
 
         return create_kernel(assignments, config=config)
@@ -315,7 +321,7 @@ class PackingKernelsCodegen:
             assignments.append(Assignment(buffer(i), acc))
 
         iteration_slice = tuple(slice(None, None, 2) for _ in range(self.dim))
-        config = replace(self.config, iteration_slice=iteration_slice, ghost_layers=0)
+        config = replace(self.config, iteration_slice=iteration_slice)
 
         ast = create_kernel(assignments, config=config)
         ast.function_name = f'packPartialCoalescence_{dir_string}' + timestep_suffix(timestep)
@@ -427,7 +433,7 @@ class PackingKernelsCodegen:
 
     def _construct_directionwise_kernel_family(self, create_ast_callback):
         subtrees = []
-        direction_symbol = TypedSymbol('dir', dtype='stencil::Direction')
+        direction_symbol = TypedSymbol('dir', dtype=custom_type('stencil::Direction'))
         for t in get_timesteps(self.streaming_pattern):
             cases_dict = dict()
             for comm_dir in self.full_stencil:
diff --git a/python/lbmpy_walberla/storage_specification.py b/python/lbmpy_walberla/storage_specification.py
index 60fd96d24..d5eba4c83 100644
--- a/python/lbmpy_walberla/storage_specification.py
+++ b/python/lbmpy_walberla/storage_specification.py
@@ -10,6 +10,7 @@ from lbmpy import LBMConfig, LBMOptimisation
 from lbmpy.advanced_streaming import is_inplace, get_accessor, Timestep
 from lbmpy.methods import AbstractLbMethod
 
+from pystencils_walberla.compat import get_default_dtype, target_string
 from pystencils_walberla.cmake_integration import CodeGenerationContext
 from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env
 from pystencils_walberla.utility import config_from_context
@@ -35,7 +36,7 @@ def generate_lbm_storage_specification(generation_context: CodeGenerationContext
     # Packing kernels should never be vectorised
     config = replace(config, cpu_vectorize_info=None)
 
-    default_dtype = config.data_type.default_factory()
+    default_dtype = get_default_dtype(config) 
     if issubclass(default_dtype.numpy_dtype.type, np.float64):
         data_type_string = "double"
     elif issubclass(default_dtype.numpy_dtype.type, np.float32):
@@ -107,7 +108,7 @@ def generate_lbm_storage_specification(generation_context: CodeGenerationContext
         'odd_write': _get_access_list(odd_write, stencil.D),
 
         'nonuniform': nonuniform,
-        'target': target.name.lower(),
+        'target': target_string(target),
         'dtype': data_type_string,
         'is_gpu': target == Target.GPU,
         'kernels': kernels,
diff --git a/python/lbmpy_walberla/sweep_collection.py b/python/lbmpy_walberla/sweep_collection.py
index bc8bdda49..d1adf2172 100644
--- a/python/lbmpy_walberla/sweep_collection.py
+++ b/python/lbmpy_walberla/sweep_collection.py
@@ -8,6 +8,7 @@ from pystencils import Target, create_kernel
 from pystencils.config import CreateKernelConfig
 from pystencils.field import Field
 from pystencils.simp import add_subexpressions_for_field_reads
+from pystencils.typing import create_type
 
 from lbmpy.advanced_streaming import is_inplace, get_accessor, Timestep
 from lbmpy.creationfunctions import LbmCollisionRule, LBMConfig, LBMOptimisation
@@ -18,6 +19,7 @@ from lbmpy.updatekernels import create_lbm_kernel, create_stream_only_kernel
 from pystencils_walberla.kernel_selection import KernelCallNode, KernelFamily
 from pystencils_walberla.utility import config_from_context
 from pystencils_walberla import generate_sweep_collection
+from pystencils_walberla.compat import IS_PYSTENCILS_2, get_default_dtype
 from lbmpy_walberla.utility import create_pdf_field
 
 from .alternating_sweeps import EvenIntegerCondition
@@ -41,10 +43,11 @@ def generate_lbm_sweep_collection(ctx, class_name: str, collision_rule: LbmColli
     # coordinates should be ordered in reverse direction i.e. zyx
     lb_method = collision_rule.method
 
-    if field_layout == 'fzyx':
-        config.cpu_vectorize_info['assume_inner_stride_one'] = True
-    elif field_layout == 'zyxf':
-        config.cpu_vectorize_info['assume_inner_stride_one'] = False
+    if not IS_PYSTENCILS_2:
+        if field_layout == 'fzyx':
+            config.cpu_vectorize_info['assume_inner_stride_one'] = True
+        elif field_layout == 'zyxf':
+            config.cpu_vectorize_info['assume_inner_stride_one'] = False
 
     src_field = lbm_optimisation.symbolic_field
     if not src_field:
@@ -73,7 +76,10 @@ def generate_lbm_sweep_collection(ctx, class_name: str, collision_rule: LbmColli
     function_generators.append(generator('stream', family("stream")))
     function_generators.append(generator('streamOnlyNoAdvancement', family("streamOnlyNoAdvancement")))
 
-    config_unoptimized = replace(config, cpu_vectorize_info=None, cpu_prepend_optimizations=[], cpu_blocking=None)
+    if IS_PYSTENCILS_2:
+        config_unoptimized = replace(config, cpu_optim=None)
+    else:
+        config_unoptimized = replace(config, cpu_vectorize_info=None, cpu_prepend_optimizations=[], cpu_blocking=None)
 
     setter_family = get_setter_family(class_name, lb_method, src_field, streaming_pattern, macroscopic_fields,
                                       config_unoptimized, set_pre_collision_pdfs)
@@ -107,7 +113,7 @@ class RefinementScaling:
 def lbm_kernel_family(class_name, kernel_name,
                       collision_rule, streaming_pattern, src_field, dst_field, config: CreateKernelConfig):
 
-    default_dtype = config.data_type.default_factory()
+    default_dtype = get_default_dtype(config)     
     if kernel_name == "streamCollide":
         def lbm_kernel(field_accessor, lb_stencil):
             return create_lbm_kernel(collision_rule, src_field, dst_field, field_accessor, data_type=default_dtype)
@@ -148,7 +154,10 @@ def lbm_kernel_family(class_name, kernel_name,
             update_rule = lbm_kernel(accessor, stencil)
             ast = create_kernel(update_rule, config=config)
             ast.function_name = 'kernel_' + kernel_name + timestep_suffix
-            ast.assumed_inner_stride_one = config.cpu_vectorize_info['assume_inner_stride_one']
+            if IS_PYSTENCILS_2:
+                ast.assumed_inner_stride_one = False
+            else:
+                ast.assumed_inner_stride_one = config.cpu_vectorize_info['assume_inner_stride_one']
             nodes.append(KernelCallNode(ast))
 
         tree = EvenIntegerCondition('timestep', nodes[0], nodes[1], parameter_dtype=np.uint8)
@@ -160,7 +169,10 @@ def lbm_kernel_family(class_name, kernel_name,
         update_rule = lbm_kernel(accessor, stencil)
         ast = create_kernel(update_rule, config=config)
         ast.function_name = 'kernel_' + kernel_name
-        ast.assumed_inner_stride_one = config.cpu_vectorize_info['assume_inner_stride_one']
+        if IS_PYSTENCILS_2:
+            ast.assumed_inner_stride_one = False
+        else:
+            ast.assumed_inner_stride_one = config.cpu_vectorize_info['assume_inner_stride_one']
         node = KernelCallNode(ast)
         family = KernelFamily(node, class_name, temporary_fields=temporary_fields, field_swaps=field_swaps)
 
@@ -173,7 +185,7 @@ def get_setter_family(class_name, lb_method, pdfs, streaming_pattern, macroscopi
     density = macroscopic_fields.get('density', 1.0)
     velocity = macroscopic_fields.get('velocity', [0.0] * dim)
 
-    default_dtype = config.data_type.default_factory()
+    default_dtype = get_default_dtype(config) 
 
     get_timestep = {"field_name": pdfs.name, "function": "getTimestep"}
     temporary_fields = ()
@@ -218,7 +230,7 @@ def get_getter_family(class_name, lb_method, pdfs, streaming_pattern, macroscopi
     if density is None and velocity is None:
         return None
 
-    default_dtype = config.data_type.default_factory()
+    default_dtype = get_default_dtype(config) 
 
     get_timestep = {"field_name": pdfs.name, "function": "getTimestep"}
     temporary_fields = ()
diff --git a/python/lbmpy_walberla/utility.py b/python/lbmpy_walberla/utility.py
index 75460d5f9..b050173e5 100644
--- a/python/lbmpy_walberla/utility.py
+++ b/python/lbmpy_walberla/utility.py
@@ -1,9 +1,10 @@
-import numpy as np
 from pystencils import CreateKernelConfig, fields
+from pystencils.typing import create_type
 
 from lbmpy.advanced_streaming import Timestep
 from lbmpy.stencils import LBStencil
 
+from pystencils_walberla.compat import get_default_dtype
 
 def timestep_suffix(timestep: Timestep):
     """ get the suffix as string for a timestep
@@ -15,7 +16,7 @@ def timestep_suffix(timestep: Timestep):
 
 
 def create_pdf_field(config: CreateKernelConfig, name: str, stencil: LBStencil, field_layout: str = 'fzyx'):
-    default_dtype = config.data_type.default_factory()
+    default_dtype = get_default_dtype(config) 
     data_type = default_dtype.numpy_dtype
     return fields(f'{name}({stencil.Q}) :{data_type}[{stencil.D}D]', layout=field_layout)
 
diff --git a/python/lbmpy_walberla/walberla_lbm_generation.py b/python/lbmpy_walberla/walberla_lbm_generation.py
index e264fb8bb..cf06c08c0 100644
--- a/python/lbmpy_walberla/walberla_lbm_generation.py
+++ b/python/lbmpy_walberla/walberla_lbm_generation.py
@@ -1,6 +1,7 @@
 # import warnings
 from typing import Callable, List
 
+from pystencils_walberla.compat import IS_PYSTENCILS_2
 
 import numpy as np
 import sympy as sp
@@ -12,20 +13,24 @@ from lbmpy.fieldaccess import CollideOnlyInplaceAccessor, StreamPullTwoFieldsAcc
 from lbmpy.relaxationrates import relaxation_rate_scaling
 from lbmpy.updatekernels import create_lbm_kernel, create_stream_only_kernel
 from pystencils import AssignmentCollection, create_kernel, Target
-from pystencils.astnodes import SympyAssignment
-from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, get_headers
-from pystencils.typing import BasicType, CastFunc, TypedSymbol
 from pystencils.field import Field
-from pystencils.node_collection import NodeCollection
 from pystencils.stencil import offset_to_direction_string
 from pystencils.sympyextensions import get_symmetric_part
-from pystencils.typing.transformations import add_types
 
 from pystencils_walberla.kernel_info import KernelInfo
 from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env
 from pystencils_walberla.utility import config_from_context
 
-cpp_printer = CustomSympyPrinter()
+
+if not IS_PYSTENCILS_2:
+    from pystencils.node_collection import NodeCollection
+    from pystencils.astnodes import SympyAssignment
+    from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, get_headers
+    from pystencils.typing import BasicType, CastFunc, TypedSymbol
+    from pystencils.typing.transformations import add_types
+
+    cpp_printer = CustomSympyPrinter()
+
 REFINEMENT_SCALE_FACTOR = sp.Symbol("level_scale_factor")
 
 
@@ -162,6 +167,9 @@ def __lattice_model(generation_context, class_name, config, lb_method, stream_co
 def generate_lattice_model(generation_context, class_name, collision_rule, field_layout='fzyx', refinement_scaling=None,
                            target=Target.CPU, data_type=None, cpu_openmp=None, cpu_vectorize_info=None,
                            **create_kernel_params):
+    
+    if IS_PYSTENCILS_2:
+        raise NotImplementedError("Lattice Model code generation is not available with pystencils 2.0")
 
     config = config_from_context(generation_context, target=target, data_type=data_type,
                                  cpu_openmp=cpu_openmp, cpu_vectorize_info=cpu_vectorize_info, **create_kernel_params)
diff --git a/python/pystencils_walberla/boundary.py b/python/pystencils_walberla/boundary.py
index 7af79ed67..ba5c7a647 100644
--- a/python/pystencils_walberla/boundary.py
+++ b/python/pystencils_walberla/boundary.py
@@ -1,16 +1,17 @@
 import numpy as np
 from jinja2 import Environment, PackageLoader, StrictUndefined
-from pystencils import Field, FieldType, Target
+from pystencils import Field, FieldType, Target, TypedSymbol
+from pystencils.typing import create_type
 from pystencils.boundaries.boundaryhandling import create_boundary_kernel
 from pystencils.boundaries.createindexlist import numpy_data_type_for_boundary_object
-from pystencils.typing import TypedSymbol, create_type
 
+from pystencils_walberla.compat import KernelFunction
 from pystencils_walberla.utility import config_from_context, struct_from_numpy_dtype
 from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env
 from pystencils_walberla.additional_data_handler import AdditionalDataHandler
 from pystencils_walberla.kernel_selection import (
     KernelFamily, AbstractKernelSelectionNode, KernelCallNode, HighLevelInterfaceSpec)
-from pystencils.astnodes import KernelFunction
+from pystencils_walberla.compat import IS_PYSTENCILS_2, target_string, get_default_dtype
 
 
 def generate_boundary(generation_context,
@@ -44,13 +45,19 @@ def generate_boundary(generation_context,
     config = config_from_context(generation_context, target=target, data_type=data_type, cpu_openmp=cpu_openmp,
                                  **create_kernel_params)
     create_kernel_params = config.__dict__
-    del create_kernel_params['target']
-    del create_kernel_params['index_fields']
-    del create_kernel_params['default_number_int']
-    del create_kernel_params['skip_independence_check']
+    create_kernel_params.pop('target', None)
+    create_kernel_params.pop('index_fields', None)
+    create_kernel_params.pop('index_field', None)
+    create_kernel_params.pop('default_number_int', None)
+    create_kernel_params.pop('index_dtype', None)
+    create_kernel_params.pop('default_dtype', None)
+    create_kernel_params.pop('skip_independence_check', None)
 
     if field_data_type is None:
-        field_data_type = config.data_type[field_name].numpy_dtype
+        if IS_PYSTENCILS_2:
+            field_data_type = config.default_dtype
+        else:
+            field_data_type = config.data_type[field_name].numpy_dtype
 
     index_struct_dtype = numpy_data_type_for_boundary_object(boundary_object, dim)
 
@@ -85,7 +92,8 @@ def generate_boundary(generation_context,
     if additional_data_handler is None:
         additional_data_handler = AdditionalDataHandler(stencil=neighbor_stencil)
 
-    default_dtype = config.data_type.default_factory()
+    default_dtype = get_default_dtype(config) 
+
     is_float = True if issubclass(default_dtype.numpy_dtype.type, np.float32) else False
 
     context = {
@@ -96,7 +104,7 @@ def generate_boundary(generation_context,
         'StructName': struct_name,
         'StructDeclaration': struct_from_numpy_dtype(struct_name, index_struct_dtype),
         'dim': dim,
-        'target': target.name.lower(),
+        'target': target_string(target),
         'namespace': namespace,
         'inner_or_boundary': boundary_object.inner_or_boundary,
         'single_link': boundary_object.single_link,
diff --git a/python/pystencils_walberla/compat.py b/python/pystencils_walberla/compat.py
new file mode 100644
index 000000000..ba3b018aa
--- /dev/null
+++ b/python/pystencils_walberla/compat.py
@@ -0,0 +1,93 @@
+from pystencils import __version__ as ps_version
+
+#   Determine if we're running pystencils 1.x or 2.x
+version_tokes = ps_version.split(".")
+
+PS_VERSION = int(version_tokes[0])
+
+IS_PYSTENCILS_2 = PS_VERSION == 2
+
+if IS_PYSTENCILS_2:
+    #   pystencils 2.x
+
+    from typing import Any
+    from enum import Enum, auto
+
+    from pystencils import DEFAULTS, Target, create_type
+    from pystencils.types import PsType, PsDereferencableType, PsCustomType
+    from pystencils import KernelFunction
+    from pystencils.backend.emission import emit_code, CAstPrinter
+
+    def get_base_type(dtype: PsType):
+        while isinstance(dtype, PsDereferencableType):
+            dtype = dtype.base_type
+        return dtype
+
+    BasicType = PsType
+
+    SHAPE_DTYPE = DEFAULTS.index_dtype
+
+    def custom_type(typename: str):
+        return PsCustomType(typename)
+    
+    def get_default_dtype(config):
+        return create_type(config.default_dtype) 
+
+    class Backend(Enum):
+        C = auto()
+        CUDA = auto()
+
+    def generate_c(
+        kfunc: KernelFunction,
+        signature_only: bool = False,
+        dialect: Any = None,
+        custom_backend=None,
+        with_globals=False,
+    ) -> str:
+
+        assert not with_globals
+        assert custom_backend is None
+
+        if signature_only:
+            return CAstPrinter().print_signature(kfunc)
+        else:
+            return emit_code(kfunc)
+
+    def backend_printer(**kwargs):
+        return CAstPrinter()
+
+    def get_headers(kfunc: KernelFunction) -> set[str]:
+        return kfunc.required_headers
+
+    def target_string(target: Target) -> str:
+        if target.is_cpu():
+            return "cpu"
+        elif target.is_gpu():
+            return "cpu"
+        else:
+            raise Exception("Invalid target.")
+
+    def get_supported_instruction_sets():
+        return ()
+
+else:
+    #   pystencils 1.x
+
+    from pystencils import Target
+    from pystencils.typing import get_base_type, BasicType
+    from pystencils.typing.typed_sympy import SHAPE_DTYPE
+    from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
+    from pystencils.enums import Backend
+    from pystencils.backends.cbackend import generate_c, get_headers, CustomSympyPrinter, KernelFunction
+    
+    def custom_type(typename: str):
+        return typename
+    
+    def get_default_dtype(config):
+        return config.data_type.default_factory()
+
+    def backend_printer(**kwargs):
+        return CustomSympyPrinter()
+
+    def target_string(target: Target) -> str:
+        return target.name.lower()
diff --git a/python/pystencils_walberla/jinja_filters.py b/python/pystencils_walberla/jinja_filters.py
index 6d05bf8ff..a32b2dd62 100644
--- a/python/pystencils_walberla/jinja_filters.py
+++ b/python/pystencils_walberla/jinja_filters.py
@@ -7,9 +7,9 @@ except ImportError:
 from collections.abc import Iterable
 import sympy as sp
 
-from pystencils import Target, Backend
-from pystencils.backends.cbackend import generate_c
-from pystencils.typing import TypedSymbol, get_base_type
+from pystencils_walberla.compat import get_base_type, generate_c, Backend, IS_PYSTENCILS_2
+
+from pystencils import Target, TypedSymbol, Field
 from pystencils.field import FieldType
 from pystencils.sympyextensions import prod
 
@@ -74,7 +74,7 @@ def make_field_type(dtype, f_size, is_gpu):
         return f"field::GhostLayerField<{dtype}, {f_size}>"
 
 
-def field_type(field, is_gpu=False):
+def field_type(field: Field, is_gpu=False):
     dtype = get_base_type(field.dtype)
     f_size = get_field_fsize(field)
     return make_field_type(dtype, f_size, is_gpu)
@@ -106,7 +106,7 @@ def get_field_stride(param):
         assert len(additional_strides) == field.index_dimensions
         f_stride_name = stride_names[-1]
         strides.extend([f"{type_str}({e} * {f_stride_name})" for e in reversed(additional_strides)])
-    return strides[param.symbol.coordinate]
+    return strides[param.coordinate if IS_PYSTENCILS_2 else param.symbol.coordinate]
 
 
 def generate_declaration(kernel_info, target=Target.CPU):
@@ -268,7 +268,7 @@ def generate_call(ctx, kernel, ghost_layers_to_include=0, cell_interval=None, st
     assert isinstance(ghost_layers_to_include, str) or ghost_layers_to_include >= 0
     ast_params = kernel.parameters
     vec_info = ctx.get('cpu_vectorize_info', None)
-    instruction_set = kernel.get_ast_attr('instruction_set')
+    instruction_set = kernel.get_ast_attr('instruction_set', default=None)
     if vec_info:
         assume_inner_stride_one = vec_info['assume_inner_stride_one']
         assume_aligned = vec_info['assume_aligned']
@@ -278,7 +278,7 @@ def generate_call(ctx, kernel, ghost_layers_to_include=0, cell_interval=None, st
         assume_aligned = False
 
     cpu_openmp = ctx.get('cpu_openmp', False)
-    kernel_ghost_layers = kernel.get_ast_attr('ghost_layers')
+    kernel_ghost_layers = kernel.get_ast_attr('ghost_layers', default=None)
 
     ghost_layers_to_include = sp.sympify(ghost_layers_to_include)
     if kernel_ghost_layers is None:
@@ -353,7 +353,7 @@ def generate_call(ctx, kernel, ghost_layers_to_include=0, cell_interval=None, st
             type_str = param.symbol.dtype.c_name
             kernel_call_lines.append(f"const {type_str} {param.symbol.name} = {casted_stride};")
         elif param.is_field_shape:
-            coord = param.symbol.coordinate
+            coord = param.coordinate if IS_PYSTENCILS_2 else param.symbol.coordinate
             field = param.fields[0]
             type_str = param.symbol.dtype.c_name
             shape = f"{type_str}({get_end_coordinates(field)[coord]})"
@@ -423,7 +423,7 @@ def generate_function_collection_call(ctx, kernel, parameters_to_ignore=(),
         parameters.append(ghost_layers)
 
     if is_gpu and "gpuStream" not in parameters_to_ignore:
-        parameters.append(f"gpuStream")
+        parameters.append("gpuStream")
 
     return ", ".join(parameters)
 
@@ -591,7 +591,9 @@ def generate_members(ctx, kernel_infos, parameters_to_ignore=None, only_fields=F
             original_field_name = field_name[:-len('_tmp')]
             f_size = get_field_fsize(f)
             field_type = make_field_type(get_base_type(f.dtype), f_size, is_gpu)
-            result.append(temporary_fieldMemberTemplate.format(type=field_type, original_field_name=original_field_name))
+            result.append(
+                temporary_fieldMemberTemplate.format(type=field_type, original_field_name=original_field_name)
+            )
 
     for kernel_info in kernel_infos:
         if hasattr(kernel_info, 'varying_parameters'):
@@ -634,13 +636,13 @@ def generate_plain_parameter_list(ctx, kernel_info, cell_interval=None, ghost_la
         if type(ghost_layers) in (int, ):
             result.append(f"const cell_idx_t ghost_layers = {ghost_layers}")
         else:
-            result.append(f"const cell_idx_t ghost_layers")
+            result.append("const cell_idx_t ghost_layers")
 
     if is_gpu:
         if stream is not None:
             result.append(f"gpuStream_t stream = {stream}")
         else:
-            result.append(f"gpuStream_t stream")
+            result.append("gpuStream_t stream")
 
     return ", ".join(result)
 
diff --git a/python/pystencils_walberla/kernel_info.py b/python/pystencils_walberla/kernel_info.py
index 586c05abe..09a8af5e2 100644
--- a/python/pystencils_walberla/kernel_info.py
+++ b/python/pystencils_walberla/kernel_info.py
@@ -1,35 +1,38 @@
 from functools import reduce
 
-from pystencils import Target
-
-from pystencils.backends.cbackend import get_headers
-from pystencils.backends.cuda_backend import CudaSympyPrinter
-from pystencils.typing.typed_sympy import SHAPE_DTYPE
-from pystencils.typing import TypedSymbol
+from pystencils import Target, TypedSymbol
 
 from pystencils_walberla.utility import merge_sorted_lists
+from pystencils_walberla.compat import backend_printer, SHAPE_DTYPE, KernelFunction, IS_PYSTENCILS_2
 
 
 # TODO KernelInfo and KernelFamily should have same interface
 class KernelInfo:
-    def __init__(self, ast, temporary_fields=(), field_swaps=(), varying_parameters=()):
+    def __init__(self, ast: KernelFunction, temporary_fields=(), field_swaps=(), varying_parameters=()):
         self.ast = ast
         self.temporary_fields = tuple(temporary_fields)
         self.field_swaps = tuple(field_swaps)
         self.varying_parameters = tuple(varying_parameters)
         self.parameters = ast.get_parameters()  # cache parameters here
 
+        if ast.target == Target.GPU and IS_PYSTENCILS_2:
+            #   TODO
+            raise NotImplementedError("Generating GPU kernels is not yet supported with pystencils 2.0")
+
     @property
     def fields_accessed(self):
         return self.ast.fields_accessed
 
-    def get_ast_attr(self, name):
+    def get_ast_attr(self, name, default=None):
         """Returns the value of an attribute of the AST managed by this KernelInfo.
         For compatibility with KernelFamily."""
-        return self.ast.__getattribute__(name)
+        try:
+            return self.ast.__getattribute__(name)
+        except AttributeError:
+            return self.ast.metadata.get(name, default)
 
     def get_headers(self):
-        all_headers = [list(get_headers(self.ast))]
+        all_headers = [list(self.ast.required_headers)]
         return reduce(merge_sorted_lists, all_headers)
 
     def generate_kernel_invocation_code(self, **kwargs):
@@ -53,7 +56,7 @@ class KernelInfo:
                 "Please only use kernels for generic field sizes!"
 
             indexing_dict = ast.indexing.call_parameters(spatial_shape_symbols)
-            sp_printer_c = CudaSympyPrinter()
+            sp_printer_c = backend_printer()
             block = tuple(sp_printer_c.doprint(e) for e in indexing_dict['block'])
             grid = tuple(sp_printer_c.doprint(e) for e in indexing_dict['grid'])
 
diff --git a/python/pystencils_walberla/kernel_selection.py b/python/pystencils_walberla/kernel_selection.py
index ad8a99867..6fa95460d 100644
--- a/python/pystencils_walberla/kernel_selection.py
+++ b/python/pystencils_walberla/kernel_selection.py
@@ -4,12 +4,9 @@ from collections import OrderedDict
 from functools import reduce
 from jinja2.filters import do_indent
 from pystencils import Target, TypedSymbol
-from pystencils.backends.cbackend import get_headers
-from pystencils.backends.cuda_backend import CudaSympyPrinter
-from pystencils.typing.typed_sympy import SHAPE_DTYPE
 
 from pystencils_walberla.utility import merge_lists_of_symbols, merge_sorted_lists
-
+from pystencils_walberla.compat import backend_printer, get_headers, SHAPE_DTYPE, IS_PYSTENCILS_2
 
 """
 
@@ -162,6 +159,10 @@ class KernelCallNode(AbstractKernelSelectionNode):
         self.ast = ast
         self.parameters = ast.get_parameters()  # cache parameters here
 
+        if ast.target == Target.GPU and IS_PYSTENCILS_2:
+            #   TODO
+            raise NotImplementedError("Generating GPU kernels is not yet supported with pystencils 2.0")
+
     @property
     def selection_parameters(self) -> Set[TypedSymbol]:
         return set()
@@ -190,7 +191,7 @@ class KernelCallNode(AbstractKernelSelectionNode):
                 "Please only use kernels for generic field sizes!"
 
             indexing_dict = ast.indexing.call_parameters(spatial_shape_symbols)
-            sp_printer_c = CudaSympyPrinter()
+            sp_printer_c = backend_printer()
             block = tuple(sp_printer_c.doprint(e) for e in indexing_dict['block'])
             grid = tuple(sp_printer_c.doprint(e) for e in indexing_dict['grid'])
 
@@ -269,13 +270,19 @@ class KernelFamily:
 
         self._ast_attrs = dict()
 
-    def get_ast_attr(self, name):
+    def get_ast_attr(self, name, default=None):
         """Returns the value of an attribute of the ASTs managed by this KernelFamily only
         if it is the same in all ASTs."""
+        def getattr(ast, name):
+            try:
+                return ast.__getattribute__(name)
+            except AttributeError:
+                return ast.metadata.get(name, default)
+
         if name not in self._ast_attrs:
-            attr = self.representative_ast.__getattribute__(name)
+            attr = getattr(self.representative_ast, name)
             for ast in self.all_asts:
-                if ast.__getattribute__(name) != attr:
+                if getattr(ast, name) != attr:
                     raise ValueError(f'Inconsistency in kernel family: Attribute {name} was different in {ast}!')
             self._ast_attrs[name] = attr
         return self._ast_attrs[name]
diff --git a/python/pystencils_walberla/pack_info.py b/python/pystencils_walberla/pack_info.py
index df84c71c7..706814f4d 100644
--- a/python/pystencils_walberla/pack_info.py
+++ b/python/pystencils_walberla/pack_info.py
@@ -6,9 +6,9 @@ from typing import Dict, Optional, Sequence, Tuple
 from jinja2 import Environment, PackageLoader, StrictUndefined
 
 from pystencils import Assignment, AssignmentCollection, Field, FieldType, Target, create_kernel
-from pystencils.backends.cbackend import get_headers
 from pystencils.stencil import inverse_direction, offset_to_direction_string
 
+from pystencils_walberla.compat import target_string, get_headers
 from pystencils_walberla.cmake_integration import CodeGenerationContext
 from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env
 from pystencils_walberla.kernel_info import KernelInfo
@@ -192,7 +192,7 @@ def generate_pack_info(generation_context: CodeGenerationContext, class_name: st
         'fused_kernel': KernelInfo(fused_kernel),
         'elements_per_cell': elements_per_cell,
         'headers': get_headers(fused_kernel),
-        'target': config.target.name.lower(),
+        'target': target_string(config.target),
         'dtype': dtype,
         'field_name': field_names.pop(),
         'namespace': namespace,
diff --git a/python/pystencils_walberla/sweep.py b/python/pystencils_walberla/sweep.py
index f6e190fde..1f9eb960a 100644
--- a/python/pystencils_walberla/sweep.py
+++ b/python/pystencils_walberla/sweep.py
@@ -4,13 +4,12 @@ from jinja2 import Environment, PackageLoader, StrictUndefined
 
 from pystencils import Target, Assignment
 from pystencils import Field, create_kernel, create_staggered_kernel
-from pystencils.astnodes import KernelFunction
-from pystencils.typing import numpy_name_to_c
 
 from pystencils_walberla.cmake_integration import CodeGenerationContext
 from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env
 from pystencils_walberla.kernel_selection import KernelCallNode, KernelFamily, HighLevelInterfaceSpec
 from pystencils_walberla.utility import config_from_context
+from pystencils_walberla.compat import target_string, IS_PYSTENCILS_2, KernelFunction
 
 
 def generate_sweep(generation_context: CodeGenerationContext, class_name: str, assignments: Sequence[Assignment],
@@ -73,8 +72,10 @@ def generate_sweep(generation_context: CodeGenerationContext, class_name: str, a
     generate_selective_sweep(generation_context, class_name, selection_tree, target=target, namespace=namespace,
                              field_swaps=field_swaps, varying_parameters=varying_parameters,
                              inner_outer_split=inner_outer_split, ghost_layers_to_include=ghost_layers_to_include,
-                             cpu_vectorize_info=config.cpu_vectorize_info,
-                             cpu_openmp=config.cpu_openmp, max_threads=max_threads, block_offset=block_offset)
+                             cpu_vectorize_info=None if IS_PYSTENCILS_2 else config.cpu_vectorize_info,  # FIXME
+                             cpu_openmp=config.cpu_openmp,
+                             max_threads=max_threads,
+                             block_offset=block_offset)
 
 
 def generate_selective_sweep(generation_context, class_name, selection_tree, interface_mappings=(), target=None,
@@ -131,13 +132,13 @@ def generate_selective_sweep(generation_context, class_name, selection_tree, int
     parameters_to_ignore = None
     if isinstance(block_offset, Iterable):
         parameters_to_ignore = [b.name for b in block_offset]
-        block_offset = tuple((b.name, numpy_name_to_c(b.dtype.numpy_dtype.name)) for b in block_offset)
+        block_offset = tuple((b.name, b.dtype.c_name) for b in block_offset)
 
     jinja_context = {
         'kernel': kernel_family,
         'namespace': namespace,
         'class_name': class_name,
-        'target': target.name.lower(),
+        'target': target_string(target),
         'field': representative_field,
         'ghost_layers_to_include': ghost_layers_to_include,
         'inner_outer_split': inner_outer_split,
@@ -199,7 +200,7 @@ def generate_sweep_collection(generation_context: CodeGenerationContext, class_n
         'namespace': namespace,
         'class_name': class_name,
         'headers': headers,
-        'target': target.name.lower(),
+        'target': target_string(target),
         'parameter_scaling': parameter_scaling,
     }
 
diff --git a/python/pystencils_walberla/templates/Boundary.tmpl.cpp b/python/pystencils_walberla/templates/Boundary.tmpl.cpp
index 644202ba6..08c2ced38 100644
--- a/python/pystencils_walberla/templates/Boundary.tmpl.cpp
+++ b/python/pystencils_walberla/templates/Boundary.tmpl.cpp
@@ -25,9 +25,9 @@
 {%- endif %}
 
 
-{% if target == 'cpu' -%}
+{% if target is equalto 'cpu' -%}
 #define FUNC_PREFIX
-{%- elif target == 'gpu' -%}
+{%- elif target is equalto 'gpu' -%}
 #define FUNC_PREFIX __global__
 {%- endif %}
 
diff --git a/python/pystencils_walberla/utility.py b/python/pystencils_walberla/utility.py
index f19a09974..d7f9ef36b 100644
--- a/python/pystencils_walberla/utility.py
+++ b/python/pystencils_walberla/utility.py
@@ -4,13 +4,15 @@ from typing import Union, Dict, DefaultDict
 import warnings
 
 from pystencils import CreateKernelConfig, Target
-from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
+from pystencils.typing import create_type
 from pystencils.boundaries.createindexlist import boundary_index_array_coordinate_names, direction_member_name
-from pystencils.typing import BasicType, create_type, get_base_type
+
+from pystencils_walberla.compat import get_supported_instruction_sets, BasicType
 
 from lbmpy import LBStencil
 
 from pystencils_walberla.cmake_integration import CodeGenerationContext
+from pystencils_walberla.compat import PS_VERSION
 
 HEADER_EXTENSIONS = {'.h', '.hpp'}
 
@@ -139,9 +141,15 @@ def config_from_context(ctx: CodeGenerationContext, target: Target = Target.CPU,
     cpu_vectorize_info['assume_sufficient_line_padding'] = cpu_vectorize_info.get('assume_sufficient_line_padding',
                                                                                   False)
 
-    config = CreateKernelConfig(target=target, data_type=data_type, default_number_float=data_type,
+    if PS_VERSION == 1:
+        config = CreateKernelConfig(target=target, data_type=data_type, default_number_float=data_type,
                                 cpu_openmp=cpu_openmp, cpu_vectorize_info=cpu_vectorize_info,
                                 **kwargs)
+    else:
+        config = CreateKernelConfig(target=target, default_dtype=data_type,
+                                cpu_openmp=cpu_openmp,
+                                cpu_vectorize_info=None,  # FIXME
+                                **kwargs)
 
     return config
 
@@ -170,7 +178,7 @@ def merge_sorted_lists(lx, ly, sort_key=lambda x: x, identity_check_key=None):
             return [x] + recursive_merge(lx_intern, ly_intern, ix_intern + 1, iy_intern)
         else:
             return [y] + recursive_merge(lx_intern, ly_intern, ix_intern, iy_intern + 1)
-    return recursive_merge(lx, ly, 0, 0)
+    return recursive_merge(list(lx), list(ly), 0, 0)
 
 
 def merge_lists_of_symbols(lists):
diff --git a/tests/lbm_generated/InterpolationNoSlip.py b/tests/lbm_generated/InterpolationNoSlip.py
index 891892f43..d1a18dd4d 100644
--- a/tests/lbm_generated/InterpolationNoSlip.py
+++ b/tests/lbm_generated/InterpolationNoSlip.py
@@ -47,7 +47,7 @@ with CodeGeneration() as ctx:
     generate_lbm_package(ctx, name="InterpolationNoSlip",
                          collision_rule=collision_rule,
                          lbm_config=lbm_config, lbm_optimisation=lbm_opt,
-                         nonuniform=True, boundaries=[no_slip, no_slip_bouzidi, no_slip_quadraticbb, ubb],
+                         nonuniform=False, boundaries=[no_slip, no_slip_bouzidi, no_slip_quadraticbb, ubb],
                          macroscopic_fields=macroscopic_fields, data_type=data_type,
                          set_pre_collision_pdfs=False)
 
-- 
GitLab