From 5bdabcf6f147241e69431809686c910c5451e58f Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 24 Jan 2025 09:29:27 +0100
Subject: [PATCH] adapt JIT compilers to DynamicType in fields. Enable
 typechecking for the field module.

---
 docs/source/contributing/dev-workflow.md   |   8 +-
 mypy.ini                                   |   3 +
 src/pystencils/field.py                    | 684 +++++++++++++++------
 src/pystencils/jit/cpu_extension_module.py |  89 +--
 src/pystencils/jit/gpu_cupy.py             |  14 +-
 tests/frontend/test_field.py               |   4 +-
 tests/runtime/test_datahandling.py         |   2 +-
 7 files changed, 574 insertions(+), 230 deletions(-)

diff --git a/docs/source/contributing/dev-workflow.md b/docs/source/contributing/dev-workflow.md
index 2aee09ba2..fe8b70e77 100644
--- a/docs/source/contributing/dev-workflow.md
+++ b/docs/source/contributing/dev-workflow.md
@@ -118,10 +118,10 @@ mypy src/pystencils
 ::::
 
 :::{note}
-Type checking is currently restricted to the `codegen`, `jit`, `backend`, and `types` modules,
-since most code in the remaining modules is significantly older and is not comprehensively
-type-annotated. As more modules are updated with type annotations, this list will expand in the future.
-If you think a new module is ready to be type-checked, add an exception clause for it in the `mypy.ini` file.
+Type checking is currently restricted only to a few modules, which are listed in the `mypy.ini` file.
+Most code in the remaining modules is significantly older and is not comprehensively type-annotated.
+As more modules are updated with type annotations, this list will expand in the future.
+If you think a new module is ready to be type-checked, add an exception clause to `mypy.ini`.
 :::
 
 ## Running the Test Suite
diff --git a/mypy.ini b/mypy.ini
index 08f073f7c..c8a7195e2 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -17,6 +17,9 @@ ignore_errors = False
 [mypy-pystencils.jit.*]
 ignore_errors = False
 
+[mypy-pystencils.field]
+ignore_errors = False
+
 [mypy-pystencils.sympyextensions.typed_sympy]
 ignore_errors = False
 
diff --git a/src/pystencils/field.py b/src/pystencils/field.py
index 826c551a7..2a74824c4 100644
--- a/src/pystencils/field.py
+++ b/src/pystencils/field.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import functools
 import hashlib
 import operator
@@ -14,14 +16,18 @@ from sympy.core.cache import cacheit
 from .defaults import DEFAULTS
 from .alignedarray import aligned_empty
 from .spatial_coordinates import x_staggered_vector, x_vector
-from .stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string
+from .stencil import (
+    direction_string_to_offset,
+    inverse_direction,
+    offset_to_direction_string,
+)
 from .types import PsType, PsStructType, create_type
 from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
 from .sympyextensions import is_integer_sequence
 from .types import UserTypeSpec
 
 
-__all__ = ['Field', 'fields', 'FieldType', 'Field']
+__all__ = ["Field", "fields", "FieldType", "Field"]
 
 
 class FieldType(Enum):
@@ -63,7 +69,10 @@ class FieldType(Enum):
     @staticmethod
     def is_staggered(field):
         assert isinstance(field, Field)
-        return field.field_type == FieldType.STAGGERED or field.field_type == FieldType.STAGGERED_FLUX
+        return (
+            field.field_type == FieldType.STAGGERED
+            or field.field_type == FieldType.STAGGERED_FLUX
+        )
 
     @staticmethod
     def is_staggered_flux(field):
@@ -132,8 +141,15 @@ class Field:
     """
 
     @staticmethod
-    def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, index_dimensions=0, 
-                       layout='numpy', index_shape=None, field_type=FieldType.GENERIC) -> 'Field':
+    def create_generic(
+        field_name,
+        spatial_dimensions,
+        dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE,
+        index_dimensions=0,
+        layout="numpy",
+        index_shape=None,
+        field_type=FieldType.GENERIC,
+    ) -> "Field":
         """
         Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
 
@@ -159,30 +175,45 @@ class Field:
             layout = spatial_layout_string_to_tuple(layout, dim=spatial_dimensions)
 
         total_dimensions = spatial_dimensions + index_dimensions
+        shape: tuple[TypedSymbol | int, ...]
+
         if index_shape is None or len(index_shape) == 0:
-            shape = tuple([
-                TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE) 
-                for i in range(total_dimensions)
-            ])
+            shape = tuple(
+                [
+                    TypedSymbol(
+                        DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE
+                    )
+                    for i in range(total_dimensions)
+                ]
+            )
         else:
             shape = tuple(
                 [
-                    TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE)
+                    TypedSymbol(
+                        DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE
+                    )
                     for i in range(spatial_dimensions)
-                ] + list(index_shape)
+                ]
+                + list(index_shape)
             )
 
-        strides = tuple([
-            TypedSymbol(DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE) 
-            for i in range(total_dimensions)
-        ])
+        strides: tuple[TypedSymbol | int, ...] = tuple(
+            [
+                TypedSymbol(
+                    DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE
+                )
+                for i in range(total_dimensions)
+            ]
+        )
 
         if not isinstance(dtype, DynamicType):
             dtype = create_type(dtype)
 
         if isinstance(dtype, PsStructType):
             if index_dimensions != 0:
-                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
+                raise ValueError(
+                    "Structured arrays/fields are not allowed to have an index dimension"
+                )
             shape += (1,)
             strides += (1,)
 
@@ -192,8 +223,12 @@ class Field:
         return Field(field_name, field_type, dtype, layout, shape, strides)
 
     @staticmethod
-    def create_from_numpy_array(field_name: str, array: np.ndarray, index_dimensions: int = 0,
-                                field_type=FieldType.GENERIC) -> 'Field':
+    def create_from_numpy_array(
+        field_name: str,
+        array: np.ndarray,
+        index_dimensions: int = 0,
+        field_type=FieldType.GENERIC,
+    ) -> "Field":
         """Creates a field based on the layout, data type, and shape of a given numpy array.
 
         Kernels created for these kind of fields can only be called with arrays of the same layout, shape and type.
@@ -206,7 +241,9 @@ class Field:
         """
         spatial_dimensions = len(array.shape) - index_dimensions
         if spatial_dimensions < 1:
-            raise ValueError("Too many index dimensions. At least one spatial dimension required")
+            raise ValueError(
+                "Too many index dimensions. At least one spatial dimension required"
+            )
 
         full_layout = get_layout_of_array(array)
         spatial_layout = tuple([i for i in full_layout if i < spatial_dimensions])
@@ -218,19 +255,28 @@ class Field:
         numpy_dtype = np.dtype(array.dtype)
         if numpy_dtype.fields is not None:
             if index_dimensions != 0:
-                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
+                raise ValueError(
+                    "Structured arrays/fields are not allowed to have an index dimension"
+                )
             shape += (1,)
             strides += (1,)
         if field_type == FieldType.STAGGERED and index_dimensions == 0:
             raise ValueError("A staggered field needs at least one index dimension")
 
-        return Field(field_name, field_type, array.dtype, spatial_layout, shape, strides)
+        return Field(
+            field_name, field_type, array.dtype, spatial_layout, shape, strides
+        )
 
     @staticmethod
-    def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0,
-                          dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, layout: str = 'numpy',
-                          strides: Optional[Sequence[int]] = None,
-                          field_type=FieldType.GENERIC) -> 'Field':
+    def create_fixed_size(
+        field_name: str,
+        shape: tuple[int, ...],
+        index_dimensions: int = 0,
+        dtype: UserTypeSpec = np.float64,
+        layout: str | tuple[int, ...] = "numpy",
+        strides: Optional[Sequence[int]] = None,
+        field_type=FieldType.GENERIC,
+    ) -> "Field":
         """
         Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout
 
@@ -243,34 +289,58 @@ class Field:
             strides: strides in bytes or None to automatically compute them from shape (assuming no padding)
             field_type: kind of field
         """
+        if isinstance(dtype, DynamicType):
+            raise ValueError(
+                "Parameter `dtype` to `Field.create_fixed_size` does not accept `DynamicType`."
+            )
+
         spatial_dimensions = len(shape) - index_dimensions
         assert spatial_dimensions >= 1
 
         if isinstance(layout, str):
-            layout = layout_string_to_tuple(layout, spatial_dimensions + index_dimensions)
+            layout = layout_string_to_tuple(
+                layout, spatial_dimensions + index_dimensions
+            )
+
+        dtype = create_type(dtype)
+
+        shape_tuple = tuple(int(s) for s in shape)
+        strides_tuple: tuple[int, ...]
 
-        shape = tuple(int(s) for s in shape)
         if strides is None:
-            strides = compute_strides(shape, layout)
+            strides_tuple = compute_strides(shape_tuple, layout)
         else:
-            assert len(strides) == len(shape)
-            strides = tuple([s // np.dtype(dtype).itemsize for s in strides])
+            assert len(strides) == len(shape_tuple)
+            np_type = dtype.numpy_dtype
 
-        if not isinstance(dtype, DynamicType):
-            dtype = create_type(dtype)
+            if np_type is None:
+                raise ValueError(
+                    f"Cannot create fixed-size field of data type {dtype} which has no NumPy representation."
+                )
+
+            strides_tuple = tuple([s // np_type.itemsize for s in strides])
 
         if isinstance(dtype, PsStructType):
             if index_dimensions != 0:
-                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
-            shape += (1,)
-            strides += (1,)
+                raise ValueError(
+                    "Structured arrays/fields are not allowed to have an index dimension"
+                )
+            shape_tuple += (1,)
+            strides_tuple += (1,)
         if field_type == FieldType.STAGGERED and index_dimensions == 0:
             raise ValueError("A staggered field needs at least one index dimension")
 
         spatial_layout = list(layout)
         for i in range(spatial_dimensions, len(layout)):
             spatial_layout.remove(i)
-        return Field(field_name, field_type, dtype, tuple(spatial_layout), shape, strides)
+        return Field(
+            field_name,
+            field_type,
+            dtype,
+            tuple(spatial_layout),
+            shape_tuple,
+            strides_tuple,
+        )
 
     def __init__(
         self,
@@ -279,14 +349,16 @@ class Field:
         dtype: UserTypeSpec | DynamicType,
         layout: tuple[int, ...],
         shape,
-        strides
+        strides,
     ):
         """Do not use directly. Use static create* methods"""
         self._field_name = field_name
         assert isinstance(field_type, FieldType)
         assert len(shape) == len(strides)
         self.field_type = field_type
-        self._dtype: PsType | DynamicType = create_type(dtype) if not isinstance(dtype, DynamicType) else dtype
+        self._dtype: PsType | DynamicType = (
+            create_type(dtype) if not isinstance(dtype, DynamicType) else dtype
+        )
         self._layout = normalize_layout(layout)
         self.shape = shape
         self.strides = strides
@@ -298,9 +370,23 @@ class Field:
 
     def new_field_with_different_name(self, new_name):
         if self.has_fixed_shape:
-            return Field(new_name, self.field_type, self._dtype, self._layout, self.shape, self.strides)
+            return Field(
+                new_name,
+                self.field_type,
+                self._dtype,
+                self._layout,
+                self.shape,
+                self.strides,
+            )
         else:
-            return Field(new_name, self.field_type, self.dtype, self.layout, self.shape, self.strides)
+            return Field(
+                new_name,
+                self.field_type,
+                self.dtype,
+                self.layout,
+                self.shape,
+                self.strides,
+            )
 
     @property
     def spatial_dimensions(self) -> int:
@@ -327,7 +413,7 @@ class Field:
 
     @property
     def spatial_shape(self) -> Tuple[int, ...]:
-        return self.shape[:self.spatial_dimensions]
+        return self.shape[: self.spatial_dimensions]
 
     @property
     def has_fixed_shape(self):
@@ -343,7 +429,7 @@ class Field:
 
     @property
     def spatial_strides(self):
-        return self.strides[:self.spatial_dimensions]
+        return self.strides[: self.spatial_dimensions]
 
     @property
     def index_strides(self):
@@ -362,15 +448,15 @@ class Field:
 
     def __repr__(self):
         if any(isinstance(s, sp.Symbol) for s in self.spatial_shape):
-            spatial_shape_str = f'{self.spatial_dimensions}d'
+            spatial_shape_str = f"{self.spatial_dimensions}d"
         else:
-            spatial_shape_str = ','.join(str(i) for i in self.spatial_shape)
-        index_shape_str = ','.join(str(i) for i in self.index_shape)
+            spatial_shape_str = ",".join(str(i) for i in self.spatial_shape)
+        index_shape_str = ",".join(str(i) for i in self.index_shape)
 
         if self.index_shape:
-            return f'{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]'
+            return f"{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]"
         else:
-            return f'{self._field_name}: {self.dtype}[{spatial_shape_str}]'
+            return f"{self._field_name}: {self.dtype}[{spatial_shape_str}]"
 
     def __str__(self):
         return self.name
@@ -391,12 +477,26 @@ class Field:
         elif len(index_shape) == 1:
             return sp.Matrix([self(i) for i in range(index_shape[0])])
         elif len(index_shape) == 2:
-            return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])])
+            return sp.Matrix(
+                [
+                    [self(i, j) for j in range(index_shape[1])]
+                    for i in range(index_shape[0])
+                ]
+            )
         elif len(index_shape) == 3:
-            return sp.Array([[[self(i, j, k) for k in range(index_shape[2])]
-                              for j in range(index_shape[1])] for i in range(index_shape[0])])
+            return sp.Array(
+                [
+                    [
+                        [self(i, j, k) for k in range(index_shape[2])]
+                        for j in range(index_shape[1])
+                    ]
+                    for i in range(index_shape[0])
+                ]
+            )
         else:
-            raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions")
+            raise NotImplementedError(
+                "center_vector is not implemented for more than 3 index dimensions"
+            )
 
     @property
     def center(self):
@@ -412,12 +512,20 @@ class Field:
         if self.index_dimensions == 0:
             return sp.Matrix([self.__getitem__(offset)])
         elif self.index_dimensions == 1:
-            return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])])
+            return sp.Matrix(
+                [self.__getitem__(offset)(i) for i in range(self.index_shape[0])]
+            )
         elif self.index_dimensions == 2:
-            return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])]
-                              for i in range(self.index_shape[0])])
+            return sp.Matrix(
+                [
+                    [self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])]
+                    for i in range(self.index_shape[0])
+                ]
+            )
         else:
-            raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions")
+            raise NotImplementedError(
+                "neighbor_vector is not implemented for more than 2 index dimensions"
+            )
 
     def __getitem__(self, offset):
         if type(offset) is np.ndarray:
@@ -427,7 +535,9 @@ class Field:
         if type(offset) is not tuple:
             offset = (offset,)
         if len(offset) != self.spatial_dimensions:
-            raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}")
+            raise ValueError(
+                f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}"
+            )
         return Field.Access(self, offset)
 
     def absolute_access(self, offset, index):
@@ -450,7 +560,9 @@ class Field:
             offset = tuple(direction_string_to_offset(offset, self.spatial_dimensions))
             offset = tuple([o * sp.Rational(1, 2) for o in offset])
         if len(offset) != self.spatial_dimensions:
-            raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}")
+            raise ValueError(
+                f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}"
+            )
 
         prefactor = 1
         neighbor_vec = [0] * len(offset)
@@ -464,25 +576,33 @@ class Field:
             if FieldType.is_staggered_flux(self):
                 prefactor = -1
         if neighbor not in self.staggered_stencil:
-            raise ValueError(f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil")
+            raise ValueError(
+                f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil"
+            )
 
         offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec))
 
         idx = self.staggered_stencil.index(neighbor)
 
-        if self.index_dimensions == 1:  # this field stores a scalar value at each staggered position
+        if (
+            self.index_dimensions == 1
+        ):  # this field stores a scalar value at each staggered position
             if index is not None:
                 raise ValueError("Cannot specify an index for a scalar staggered field")
             return prefactor * Field.Access(self, offset, (idx,))
         else:  # this field stores a vector or tensor at each staggered position
             if index is None:
-                raise ValueError(f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}")
+                raise ValueError(
+                    f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}"
+                )
             if type(index) is np.ndarray:
                 index = tuple(index)
             if type(index) is not tuple:
                 index = (index,)
             if self.index_dimensions != len(index) + 1:
-                raise ValueError(f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}")
+                raise ValueError(
+                    f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}"
+                )
 
             return prefactor * Field.Access(self, offset, (idx, *index))
 
@@ -493,30 +613,54 @@ class Field:
         if self.index_dimensions == 1:
             return sp.Matrix([self.staggered_access(offset)])
         elif self.index_dimensions == 2:
-            return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])])
+            return sp.Matrix(
+                [self.staggered_access(offset, i) for i in range(self.index_shape[1])]
+            )
         elif self.index_dimensions == 3:
-            return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])]
-                              for i in range(self.index_shape[1])])
+            return sp.Matrix(
+                [
+                    [
+                        self.staggered_access(offset, (i, k))
+                        for k in range(self.index_shape[2])
+                    ]
+                    for i in range(self.index_shape[1])
+                ]
+            )
         else:
-            raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions")
+            raise NotImplementedError(
+                "staggered_vector_access is not implemented for more than 3 index dimensions"
+            )
 
     @property
     def staggered_stencil(self):
         assert FieldType.is_staggered(self)
         stencils = {
-            2: {
-                2: ["W", "S"],  # D2Q5
-                4: ["W", "S", "SW", "NW"]  # D2Q9
-            },
+            2: {2: ["W", "S"], 4: ["W", "S", "SW", "NW"]},  # D2Q5  # D2Q9
             3: {
                 3: ["W", "S", "B"],  # D3Q7
                 7: ["W", "S", "B", "BSW", "TSW", "BNW", "TNW"],  # D3Q15
                 9: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS"],  # D3Q19
-                13: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS", "BSW", "TSW", "BNW", "TNW"]  # D3Q27
-            }
+                13: [
+                    "W",
+                    "S",
+                    "B",
+                    "SW",
+                    "NW",
+                    "BW",
+                    "TW",
+                    "BS",
+                    "TS",
+                    "BSW",
+                    "TSW",
+                    "BNW",
+                    "TNW",
+                ],  # D3Q27
+            },
         }
         if not self.index_shape[0] in stencils[self.spatial_dimensions]:
-            raise ValueError(f"No known stencil has {self.index_shape[0]} staggered points")
+            raise ValueError(
+                f"No known stencil has {self.index_shape[0]} staggered points"
+            )
         return stencils[self.spatial_dimensions][self.index_shape[0]]
 
     @property
@@ -529,13 +673,15 @@ class Field:
         return Field.Access(self, center)(*args, **kwargs)
 
     def hashable_contents(self):
-        return (self._layout,
-                self.shape,
-                self.strides,
-                self.field_type,
-                self._field_name,
-                self.latex_name,
-                self._dtype)
+        return (
+            self._layout,
+            self.shape,
+            self.strides,
+            self.field_type,
+            self._field_name,
+            self.latex_name,
+            self._dtype,
+        )
 
     def __hash__(self):
         return hash(self.hashable_contents())
@@ -547,36 +693,53 @@ class Field:
 
     @property
     def physical_coordinates(self):
-        if hasattr(self.coordinate_transform, '__call__'):
-            return self.coordinate_transform(self.coordinate_origin + x_vector(self.spatial_dimensions))
+        if hasattr(self.coordinate_transform, "__call__"):
+            return self.coordinate_transform(
+                self.coordinate_origin + x_vector(self.spatial_dimensions)
+            )
         else:
-            return self.coordinate_transform @ (self.coordinate_origin + x_vector(self.spatial_dimensions))
+            return self.coordinate_transform @ (
+                self.coordinate_origin + x_vector(self.spatial_dimensions)
+            )
 
     @property
     def physical_coordinates_staggered(self):
-        return self.coordinate_transform @ \
-            (self.coordinate_origin + x_staggered_vector(self.spatial_dimensions))
+        return self.coordinate_transform @ (
+            self.coordinate_origin + x_staggered_vector(self.spatial_dimensions)
+        )
 
     def index_to_physical(self, index_coordinates: sp.Matrix, staggered=False):
         if staggered:
-            index_coordinates = sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates
-        if hasattr(self.coordinate_transform, '__call__'):
+            index_coordinates = (
+                sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates
+            )
+        if hasattr(self.coordinate_transform, "__call__"):
             return self.coordinate_transform(self.coordinate_origin + index_coordinates)
         else:
-            return self.coordinate_transform @ (self.coordinate_origin + index_coordinates)
+            return self.coordinate_transform @ (
+                self.coordinate_origin + index_coordinates
+            )
 
     def physical_to_index(self, physical_coordinates: sp.Matrix, staggered=False):
-        if hasattr(self.coordinate_transform, '__call__'):
-            if hasattr(self.coordinate_transform, 'inv'):
-                return self.coordinate_transform.inv()(physical_coordinates) - self.coordinate_origin
+        if hasattr(self.coordinate_transform, "__call__"):
+            if hasattr(self.coordinate_transform, "inv"):
+                return (
+                    self.coordinate_transform.inv()(physical_coordinates)
+                    - self.coordinate_origin
+                )
             else:
-                idx = sp.Matrix(sp.symbols(f'index_coordinates:{self.ndim}', real=True))
+                idx = sp.Matrix(sp.symbols(f"index_coordinates:{self.ndim}", real=True))
                 rtn = sp.solve(self.index_to_physical(idx) - physical_coordinates, idx)
-                assert rtn, f'Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}'
+                assert (
+                    rtn
+                ), f"Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}"
                 return rtn
 
         else:
-            rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin
+            rtn = (
+                self.coordinate_transform.inv() @ physical_coordinates
+                - self.coordinate_origin
+            )
         if staggered:
             rtn = sp.Matrix([i - 0.5 for i in rtn])
 
@@ -605,18 +768,40 @@ class Field:
             >>> central_y_component.at_index(0)  # change component
             v_C^0
         """
+
         _iterable = False  # see https://i10git.cs.fau.de/pycodegen/pystencils/-/merge_requests/166#note_10680
 
         __match_args__ = ("field", "offsets", "index")
 
+        #   for the type checker
+        _field: Field
+        _offsets: tuple[int | sp.Basic, ...]
+        _offsetName: str
+        _superscript: None | str
+        _index: tuple[int | sp.Basic, ...] | str
+        _indirect_addressing_fields: set[Field]
+        _is_absolute_access: bool
+
         def __new__(cls, name, *args, **kwargs):
             obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs)
             return obj
 
-        def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False, dtype=None):
+        def __new_stage2__(  # type: ignore
+            self,
+            field: Field,
+            offsets: tuple[int, ...] = (0, 0, 0),
+            idx: None | tuple[int, ...] | str = None,
+            is_absolute_access: bool = False,
+            dtype: PsType | None = None,
+        ):
             field_name = field.name
             offsets_and_index = (*offsets, *idx) if idx is not None else offsets
-            constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index])
+            constant_offsets = not any(
+                [
+                    isinstance(o, sp.Basic) and not o.is_Integer
+                    for o in offsets_and_index
+                ]
+            )
 
             if not idx:
                 idx = tuple([0] * field.index_dimensions)
@@ -630,31 +815,36 @@ class Field:
                 else:
                     idx_str = ",".join([str(e) for e in idx])
                     superscript = idx_str
-                if field.has_fixed_index_shape and not isinstance(field.dtype, PsStructType):
+                if field.has_fixed_index_shape and not isinstance(
+                    field.dtype, PsStructType
+                ):
                     for i, bound in zip(idx, field.index_shape):
                         if i >= bound:
                             raise ValueError("Field index out of bounds")
             else:
-                offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[:12]
+                offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[
+                    :12
+                ]
                 superscript = None
 
             symbol_name = f"{field_name}_{offset_name}"
             if superscript is not None:
                 symbol_name += "^" + superscript
 
+            obj: Field.Access
             if dtype is not None:
                 obj = super(Field.Access, self).__xnew__(self, symbol_name, dtype)
             else:
                 obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype)
 
             obj._field = field
-            obj._offsets = []
+            _offsets: list[sp.Basic | int] = []
             for o in offsets:
                 if isinstance(o, sp.Basic):
-                    obj._offsets.append(o)
+                    _offsets.append(o)
                 else:
-                    obj._offsets.append(int(o))
-            obj._offsets = tuple(sp.sympify(obj._offsets))
+                    _offsets.append(int(o))
+            obj._offsets = tuple(sp.sympify(_offsets))
             obj._offsetName = offset_name
             obj._superscript = superscript
             obj._index = idx
@@ -662,19 +852,33 @@ class Field:
             obj._indirect_addressing_fields = set()
             for e in chain(obj._offsets, obj._index):
                 if isinstance(e, sp.Basic):
-                    obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access))
+                    obj._indirect_addressing_fields.update(
+                        a.field for a in e.atoms(Field.Access)
+                    )
 
             obj._is_absolute_access = is_absolute_access
             return obj
 
         def __getnewargs__(self):
-            return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype
+            return (
+                self.field,
+                self.offsets,
+                self.index,
+                self.is_absolute_access,
+                self.dtype,
+            )
 
         def __getnewargs_ex__(self):
-            return (self.field, self.offsets, self.index, self.is_absolute_access, self.dtype), {}
+            return (
+                self.field,
+                self.offsets,
+                self.index,
+                self.is_absolute_access,
+                self.dtype,
+            ), {}
 
         # noinspection SpellCheckingInspection
-        __xnew__ = staticmethod(__new_stage2__)
+        __xnew__ = staticmethod(__new_stage2__)  # type: ignore
         # noinspection SpellCheckingInspection
         __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
 
@@ -688,22 +892,34 @@ class Field:
                 idx = ()
 
             if len(idx) != self.field.index_dimensions:
-                raise ValueError(f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}")
+                raise ValueError(
+                    f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}"
+                )
             if len(idx) == 1 and isinstance(idx[0], str):
                 struct_type = self.field.dtype
                 assert isinstance(struct_type, PsStructType)
                 dtype = struct_type.get_member(idx[0]).dtype
-                return Field.Access(self.field, self._offsets, idx,
-                                    is_absolute_access=self.is_absolute_access, dtype=dtype)
+                return Field.Access(
+                    self.field,
+                    self._offsets,
+                    idx,
+                    is_absolute_access=self.is_absolute_access,
+                    dtype=dtype,
+                )
             else:
-                return Field.Access(self.field, self._offsets, idx,
-                                    is_absolute_access=self.is_absolute_access, dtype=self.dtype)
+                return Field.Access(
+                    self.field,
+                    self._offsets,
+                    idx,
+                    is_absolute_access=self.is_absolute_access,
+                    dtype=self.dtype,
+                )
 
         def __getitem__(self, *idx):
             return self.__call__(*idx)
 
         @property
-        def field(self) -> 'Field':
+        def field(self) -> "Field":
             """Field that the Access points to"""
             return self._field
 
@@ -715,7 +931,7 @@ class Field:
         @property
         def required_ghost_layers(self) -> int:
             """Largest spatial distance that is accessed."""
-            return int(np.max(np.abs(self._offsets)))
+            return int(np.max(np.abs(self._offsets)))  # type: ignore
 
         @property
         def nr_of_coordinates(self):
@@ -737,7 +953,7 @@ class Field:
             """Value of index coordinates as tuple."""
             return self._index
 
-        def neighbor(self, coord_id: int, offset: int) -> 'Field.Access':
+        def neighbor(self, coord_id: int, offset: int) -> "Field.Access":
             """Returns a new Access with changed spatial coordinates.
 
             Args:
@@ -751,10 +967,15 @@ class Field:
             """
             offset_list = list(self.offsets)
             offset_list[coord_id] += offset
-            return Field.Access(self.field, tuple(offset_list), self.index,
-                                is_absolute_access=self.is_absolute_access, dtype=self.dtype)
+            return Field.Access(
+                self.field,
+                tuple(offset_list),
+                self.index,
+                is_absolute_access=self.is_absolute_access,
+                dtype=self.dtype,
+            )
 
-        def get_shifted(self, *shift) -> 'Field.Access':
+        def get_shifted(self, *shift) -> "Field.Access":
             """Returns a new Access with changed spatial coordinates
 
             Example:
@@ -762,13 +983,15 @@ class Field:
                 >>> f[0,0].get_shifted(1, 1)
                 f_NE
             """
-            return Field.Access(self.field,
-                                tuple(a + b for a, b in zip(shift, self.offsets)),
-                                self.index,
-                                is_absolute_access=self.is_absolute_access,
-                                dtype=self.dtype)
+            return Field.Access(
+                self.field,
+                tuple(a + b for a, b in zip(shift, self.offsets)),
+                self.index,
+                is_absolute_access=self.is_absolute_access,
+                dtype=self.dtype,
+            )
 
-        def at_index(self, *idx_tuple) -> 'Field.Access':
+        def at_index(self, *idx_tuple) -> "Field.Access":
             """Returns new Access with changed index.
 
             Example:
@@ -776,15 +999,22 @@ class Field:
                 >>> f(0).at_index(8)
                 f_C^8
             """
-            return Field.Access(self.field, self.offsets, idx_tuple,
-                                is_absolute_access=self.is_absolute_access, dtype=self.dtype)
+            return Field.Access(
+                self.field,
+                self.offsets,
+                idx_tuple,
+                is_absolute_access=self.is_absolute_access,
+                dtype=self.dtype,
+            )
 
         def _eval_subs(self, old, new):
-            return Field.Access(self.field,
-                                tuple(sp.sympify(a).subs(old, new) for a in self.offsets),
-                                tuple(sp.sympify(a).subs(old, new) for a in self.index),
-                                is_absolute_access=self.is_absolute_access,
-                                dtype=self.dtype)
+            return Field.Access(
+                self.field,
+                tuple(sp.sympify(a).subs(old, new) for a in self.offsets),
+                tuple(sp.sympify(a).subs(old, new) for a in self.index),
+                is_absolute_access=self.is_absolute_access,
+                dtype=self.dtype,
+            )
 
         @property
         def is_absolute_access(self) -> bool:
@@ -792,30 +1022,43 @@ class Field:
             return self._is_absolute_access
 
         @property
-        def indirect_addressing_fields(self) -> Set['Field']:
+        def indirect_addressing_fields(self) -> Set["Field"]:
             """Returns a set of fields that the access depends on.
 
-             e.g. f[index_field[1, 0]], the outer access to f depends on index_field
-             """
+            e.g. f[index_field[1, 0]], the outer access to f depends on index_field
+            """
             return self._indirect_addressing_fields
 
         def _hashable_content(self):
             super_class_contents = super(Field.Access, self)._hashable_content()
-            return (super_class_contents, self._field.hashable_contents(), *self._index,
-                    *self._offsets, self._is_absolute_access)
+            return (
+                super_class_contents,
+                self._field.hashable_contents(),
+                *self._index,
+                *self._offsets,
+                self._is_absolute_access,
+            )
 
         def _staggered_offset(self, offsets, index):
             assert FieldType.is_staggered(self._field)
             neighbor = self._field.staggered_stencil[index]
-            neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions)
-            return [(o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)]
+            neighbor = direction_string_to_offset(
+                neighbor, self._field.spatial_dimensions
+            )
+            return [
+                (o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)
+            ]
 
         def _latex(self, _):
             n = self._field.latex_name if self._field.latex_name else self._field.name
             offset_str = ",".join([sp.latex(o) for o in self.offsets])
             if FieldType.is_staggered(self._field):
-                offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
-                                       for i in range(len(self.offsets))])
+                offset_str = ",".join(
+                    [
+                        sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
+                        for i in range(len(self.offsets))
+                    ]
+                )
             if self.is_absolute_access:
                 offset_str = f"\\mathbf{offset_str}"
             elif self.field.spatial_dimensions > 1:
@@ -836,8 +1079,12 @@ class Field:
             n = self._field.latex_name if self._field.latex_name else self._field.name
             offset_str = ",".join([sp.latex(o) for o in self.offsets])
             if FieldType.is_staggered(self._field):
-                offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
-                                       for i in range(len(self.offsets))])
+                offset_str = ",".join(
+                    [
+                        sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
+                        for i in range(len(self.offsets))
+                    ]
+                )
             if self.is_absolute_access:
                 offset_str = f"[abs]{offset_str}"
 
@@ -853,8 +1100,13 @@ class Field:
                     return f"{n}[{offset_str}]"
 
 
-def fields(description=None, index_dimensions=0, layout=None,
-           field_type=FieldType.GENERIC, **kwargs) -> Union[Field, List[Field]]:
+def fields(
+    description=None,
+    index_dimensions=0,
+    layout=None,
+    field_type=FieldType.GENERIC,
+    **kwargs,
+) -> Union[Field, List[Field]]:
     """Creates pystencils fields from a string description.
 
     Examples:
@@ -888,31 +1140,63 @@ def fields(description=None, index_dimensions=0, layout=None,
     result = []
     if description:
         field_descriptions, dtype, shape = _parse_description(description)
-        layout = 'numpy' if layout is None else layout
+        layout = "numpy" if layout is None else layout
         for field_name, idx_shape in field_descriptions:
             if field_name in kwargs:
                 arr = kwargs[field_name]
-                idx_shape_of_arr = () if not len(idx_shape) else arr.shape[-len(idx_shape):]
+                idx_shape_of_arr = (
+                    () if not len(idx_shape) else arr.shape[-len(idx_shape):]
+                )
                 assert idx_shape_of_arr == idx_shape
-                f = Field.create_from_numpy_array(field_name, kwargs[field_name], index_dimensions=len(idx_shape),
-                                                  field_type=field_type)
+                f = Field.create_from_numpy_array(
+                    field_name,
+                    kwargs[field_name],
+                    index_dimensions=len(idx_shape),
+                    field_type=field_type,
+                )
             elif isinstance(shape, tuple):
-                f = Field.create_fixed_size(field_name, shape + idx_shape, dtype=dtype,
-                                            index_dimensions=len(idx_shape), layout=layout, field_type=field_type)
+                f = Field.create_fixed_size(
+                    field_name,
+                    shape + idx_shape,
+                    dtype=dtype,
+                    index_dimensions=len(idx_shape),
+                    layout=layout,
+                    field_type=field_type,
+                )
             elif isinstance(shape, int):
-                f = Field.create_generic(field_name, spatial_dimensions=shape, dtype=dtype,
-                                         index_shape=idx_shape, layout=layout, field_type=field_type)
+                f = Field.create_generic(
+                    field_name,
+                    spatial_dimensions=shape,
+                    dtype=dtype,
+                    index_shape=idx_shape,
+                    layout=layout,
+                    field_type=field_type,
+                )
             elif shape is None:
-                f = Field.create_generic(field_name, spatial_dimensions=2, dtype=dtype,
-                                         index_shape=idx_shape, layout=layout, field_type=field_type)
+                f = Field.create_generic(
+                    field_name,
+                    spatial_dimensions=2,
+                    dtype=dtype,
+                    index_shape=idx_shape,
+                    layout=layout,
+                    field_type=field_type,
+                )
             else:
                 assert False
             result.append(f)
     else:
-        assert layout is None, "Layout can not be specified when creating Field from numpy array"
+        assert (
+            layout is None
+        ), "Layout can not be specified when creating Field from numpy array"
         for field_name, arr in kwargs.items():
-            result.append(Field.create_from_numpy_array(field_name, arr, index_dimensions=index_dimensions,
-                                                        field_type=field_type))
+            result.append(
+                Field.create_from_numpy_array(
+                    field_name,
+                    arr,
+                    index_dimensions=index_dimensions,
+                    field_type=field_type,
+                )
+            )
 
     if len(result) == 0:
         raise ValueError("Could not parse field description")
@@ -922,16 +1206,27 @@ def fields(description=None, index_dimensions=0, layout=None,
         return result
 
 
-def get_layout_from_strides(strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None):
+def get_layout_from_strides(
+    strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None
+):
     index_dimension_ids = [] if index_dimension_ids is None else index_dimension_ids
     coordinates = list(range(len(strides)))
-    relevant_strides = [stride for i, stride in enumerate(strides) if i not in index_dimension_ids]
-    result = [x for (y, x) in sorted(zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True)]
+    relevant_strides = [
+        stride for i, stride in enumerate(strides) if i not in index_dimension_ids
+    ]
+    result = [
+        x
+        for (y, x) in sorted(
+            zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True
+        )
+    ]
     return normalize_layout(result)
 
 
-def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None):
-    """ Returns a list indicating the memory layout (linearization order) of the numpy array.
+def get_layout_of_array(
+    arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None
+):
+    """Returns a list indicating the memory layout (linearization order) of the numpy array.
 
     Examples:
         >>> get_layout_of_array(np.zeros([3,3,3]))
@@ -948,7 +1243,9 @@ def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int]
     return get_layout_from_strides(arr.strides, index_dimension_ids)
 
 
-def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0, **kwargs):
+def create_numpy_array_with_layout(
+    shape, layout, alignment=False, byte_offset=0, **kwargs
+):
     """Creates numpy array with given memory layout.
 
     Args:
@@ -972,7 +1269,10 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
         if cur_layout[i] != layout[i]:
             index_to_swap_with = cur_layout.index(layout[i])
             swaps.append((i, index_to_swap_with))
-            cur_layout[i], cur_layout[index_to_swap_with] = cur_layout[index_to_swap_with], cur_layout[i]
+            cur_layout[i], cur_layout[index_to_swap_with] = (
+                cur_layout[index_to_swap_with],
+                cur_layout[i],
+            )
     assert tuple(cur_layout) == tuple(layout)
 
     shape = list(shape)
@@ -980,7 +1280,7 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
         shape[a], shape[b] = shape[b], shape[a]
 
     if not alignment:
-        res = np.empty(shape, order='c', **kwargs)
+        res = np.empty(shape, order="c", **kwargs)
     else:
         res = aligned_empty(shape, alignment, byte_offset=byte_offset, **kwargs)
 
@@ -992,37 +1292,43 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
 def spatial_layout_string_to_tuple(layout_str: str, dim: int) -> Tuple[int, ...]:
     if dim <= 0:
         raise ValueError("Dimensionality must be positive")
-    
+
     layout_str = layout_str.lower()
 
-    if layout_str in ('fzyx', 'zyxf', 'soa', 'aos'):
+    if layout_str in ("fzyx", "zyxf", "soa", "aos"):
         if dim > 3:
-            raise ValueError(f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3.")
+            raise ValueError(
+                f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3."
+            )
         return tuple(reversed(range(dim)))
-    
-    if layout_str in ('f', 'reverse_numpy'):
+
+    if layout_str in ("f", "reverse_numpy"):
         return tuple(reversed(range(dim)))
-    elif layout_str in ('c', 'numpy'):
+    elif layout_str in ("c", "numpy"):
         return tuple(range(dim))
     raise ValueError("Unknown layout descriptor " + layout_str)
 
 
-def layout_string_to_tuple(layout_str, dim):
+def layout_string_to_tuple(layout_str, dim) -> tuple[int, ...]:
     if dim <= 0:
         raise ValueError("Dimensionality must be positive")
-    
+
     layout_str = layout_str.lower()
-    if layout_str == 'fzyx' or layout_str == 'soa':
+    if layout_str == "fzyx" or layout_str == "soa":
         if dim > 4:
-            raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.")
+            raise ValueError(
+                f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4."
+            )
         return tuple(reversed(range(dim)))
-    elif layout_str == 'zyxf' or layout_str == 'aos':
+    elif layout_str == "zyxf" or layout_str == "aos":
         if dim > 4:
-            raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.")
+            raise ValueError(
+                f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4."
+            )
         return tuple(reversed(range(dim - 1))) + (dim - 1,)
-    elif layout_str == 'f' or layout_str == 'reverse_numpy':
+    elif layout_str == "f" or layout_str == "reverse_numpy":
         return tuple(reversed(range(dim)))
-    elif layout_str == 'c' or layout_str == 'numpy':
+    elif layout_str == "c" or layout_str == "numpy":
         return tuple(range(dim))
     raise ValueError("Unknown layout descriptor " + layout_str)
 
@@ -1057,7 +1363,8 @@ def compute_strides(shape, layout):
 
 # ---------------------------------------- Parsing of string in fields() function --------------------------------------
 
-field_description_regex = re.compile(r"""
+field_description_regex = re.compile(
+    r"""
     \s*                 # ignore leading white spaces
     (\w+)               # identifier is a sequence of alphanumeric characters, is stored in first group
     (?:                 # optional index specification e.g. (1, 4, 2)
@@ -1068,9 +1375,12 @@ field_description_regex = re.compile(r"""
         \s*
     )?
     \s*,?\s*             # ignore trailing white spaces and comma
-""", re.VERBOSE)
+""",
+    re.VERBOSE,
+)
 
-type_description_regex = re.compile(r"""
+type_description_regex = re.compile(
+    r"""
     \s*
     (\w+)?       # optional dtype
     \s*
@@ -1078,7 +1388,9 @@ type_description_regex = re.compile(r"""
         ([^\]]+)
     \]
     \s*
-""", re.VERBOSE | re.IGNORECASE)
+""",
+    re.VERBOSE | re.IGNORECASE,
+)
 
 
 def _parse_part1(d):
@@ -1104,19 +1416,19 @@ def _parse_description(description):
             else:
                 dtype = DynamicType.NUMERIC_TYPE
 
-            if size_info.endswith('d'):
+            if size_info.endswith("d"):
                 size_info = int(size_info[:-1])
             else:
                 size_info = tuple(int(e) for e in size_info.split(","))
-            
+
             return dtype, size_info
         else:
             raise ValueError("Could not parse field description")
 
-    if ':' in description:
-        field_description, field_info = description.split(':')
+    if ":" in description:
+        field_description, field_info = description.split(":")
     else:
-        field_description, field_info = description, 'float64[2D]'
+        field_description, field_info = description, "float64[2D]"
 
     fields_info = [e for e in _parse_part1(field_description)]
     if not field_info:
diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py
index befb033e6..48a7eebeb 100644
--- a/src/pystencils/jit/cpu_extension_module.py
+++ b/src/pystencils/jit/cpu_extension_module.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, cast
 
 from os import path
 import hashlib
@@ -19,6 +19,7 @@ from ..types import (
     PsUnsignedIntegerType,
     PsSignedIntegerType,
     PsIeeeFloatType,
+    PsPointerType,
 )
 from ..types.quick import Fp, SInt, UInt
 from ..field import Field
@@ -121,9 +122,7 @@ class PsKernelExtensioNModule:
 
 def emit_call_wrapper(function_name: str, kernel: Kernel) -> str:
     builder = CallWrapperBuilder()
-
-    for p in kernel.parameters:
-        builder.extract_parameter(p)
+    builder.extract_params(kernel.parameters)
 
     # for c in kernel.constraints:
     #     builder.check_constraint(c)
@@ -199,7 +198,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 """
 
     def __init__(self) -> None:
-        self._array_buffers: dict[Field, str] = dict()
+        self._buffer_types: dict[Field, PsType] = dict()
         self._array_extractions: dict[Field, str] = dict()
         self._array_frees: dict[Field, str] = dict()
 
@@ -220,9 +219,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
                 return "PyLong_AsUnsignedLong"
 
             case _:
-                raise ValueError(
-                    f"Don't know how to cast Python objects to {dtype}"
-                )
+                raise ValueError(f"Don't know how to cast Python objects to {dtype}")
 
     def _type_char(self, dtype: PsType) -> str | None:
         if isinstance(
@@ -233,37 +230,39 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         else:
             return None
 
-    def extract_field(self, field: Field) -> str:
+    def get_field_buffer(self, field: Field):
+        """Get the Python buffer object for the given field."""
+        return f"buffer_{field.name}"
+
+    def extract_field(self, field: Field):
         """Adds an array, and returns the name of the underlying Py_Buffer."""
         if field not in self._array_extractions:
             extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name)
+            actual_dtype = self._buffer_types[field]
 
             #   Check array type
-            type_char = self._type_char(field.dtype)
+            type_char = self._type_char(actual_dtype)
             if type_char is not None:
                 dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'"
                 extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
                     cond=dtype_cond,
                     what="data type",
                     name=field.name,
-                    expected=str(field.dtype),
+                    expected=str(actual_dtype),
                 )
 
             #   Check item size
-            itemsize = field.dtype.itemsize
+            itemsize = actual_dtype.itemsize
             item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}"
             extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
                 cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize
             )
 
-            self._array_buffers[field] = f"buffer_{field.name}"
             self._array_extractions[field] = extraction_code
 
             release_code = f"PyBuffer_Release(&buffer_{field.name});"
             self._array_frees[field] = release_code
 
-        return self._array_buffers[field]
-
     def extract_scalar(self, param: Parameter) -> str:
         if param not in self._scalar_extractions:
             extract_func = self._scalar_extractor(param.dtype)
@@ -279,7 +278,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
     def extract_array_assoc_var(self, param: Parameter) -> str:
         if param not in self._array_assoc_var_extractions:
             field = param.fields[0]
-            buffer = self.extract_field(field)
+            buffer = self.get_field_buffer(field)
+            buffer_dtype = self._buffer_types[field]
             code: str | None = None
 
             for prop in param.properties:
@@ -293,7 +293,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
                     case FieldStride(_, coord):
                         code = (
                             f"{param.dtype.c_string()} {param.name} = "
-                            f"{buffer}.strides[{coord}] / {field.dtype.itemsize};"
+                            f"{buffer}.strides[{coord}] / {buffer_dtype.itemsize};"
                         )
                         break
             assert code is not None
@@ -302,29 +302,48 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
         return param.name
 
-    def extract_parameter(self, param: Parameter):
-        if param.is_field_parameter:
-            self.extract_array_assoc_var(param)
-        else:
-            self.extract_scalar(param)
+    def extract_params(self, params: tuple[Parameter, ...]):
+        for param in params:
+            if ptr_props := param.get_properties(FieldBasePtr):
+                prop: FieldBasePtr = cast(FieldBasePtr, ptr_props.pop())
+                field = prop.field
+                actual_field_type: PsType
+
+                from .. import DynamicType
+
+                if isinstance(field.dtype, DynamicType):
+                    ptr_type = param.dtype
+                    assert isinstance(ptr_type, PsPointerType)
+                    actual_field_type = ptr_type.base_type
+                else:
+                    actual_field_type = field.dtype
+
+                self._buffer_types[prop.field] = actual_field_type
+                self.extract_field(prop.field)
+
+        for param in params:
+            if param.is_field_parameter:
+                self.extract_array_assoc_var(param)
+            else:
+                self.extract_scalar(param)
 
-#     def check_constraint(self, constraint: KernelParamsConstraint):
-#         variables = constraint.get_parameters()
+    #     def check_constraint(self, constraint: KernelParamsConstraint):
+    #         variables = constraint.get_parameters()
 
-#         for var in variables:
-#             self.extract_parameter(var)
+    #         for var in variables:
+    #             self.extract_parameter(var)
 
-#         cond = constraint.to_code()
+    #         cond = constraint.to_code()
 
-#         code = f"""
-# if(!({cond}))
-# {{
-#     PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); 
-#     return NULL;
-# }}
-# """
+    #         code = f"""
+    # if(!({cond}))
+    # {{
+    #     PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}");
+    #     return NULL;
+    # }}
+    # """
 
-#         self._constraint_checks.append(code)
+    #         self._constraint_checks.append(code)
 
     def call(self, kernel: Kernel, params: tuple[Parameter, ...]):
         param_list = ", ".join(p.name for p in params)
diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py
index c208ac219..a407bb75e 100644
--- a/src/pystencils/jit/gpu_cupy.py
+++ b/src/pystencils/jit/gpu_cupy.py
@@ -19,7 +19,7 @@ from ..codegen import (
     Parameter,
 )
 from ..codegen.properties import FieldShape, FieldStride, FieldBasePtr
-from ..types import PsStructType
+from ..types import PsStructType, PsPointerType
 
 from ..include import get_pystencils_include_path
 
@@ -160,8 +160,18 @@ class CupyKernelWrapper(KernelWrapper):
                 for prop in kparam.properties:
                     match prop:
                         case FieldBasePtr(field):
+
+                            elem_dtype: PsType
+
+                            from .. import DynamicType
+                            if isinstance(field.dtype, DynamicType):
+                                assert isinstance(kparam.dtype, PsPointerType)
+                                elem_dtype = kparam.dtype.base_type
+                            else:
+                                elem_dtype = field.dtype
+
                             arr = kwargs[field.name]
-                            if arr.dtype != field.dtype.numpy_dtype:
+                            if arr.dtype != elem_dtype.numpy_dtype:
                                 raise JitError(
                                     f"Data type mismatch at array argument {field.name}:"
                                     f"Expected {field.dtype}, got {arr.dtype}"
diff --git a/tests/frontend/test_field.py b/tests/frontend/test_field.py
index 6521e114f..56c3bdabb 100644
--- a/tests/frontend/test_field.py
+++ b/tests/frontend/test_field.py
@@ -61,13 +61,13 @@ def test_field_basic():
     neighbor = field_access.neighbor(coord_id=0, offset=-2)
     assert neighbor.offsets == (-1, 1)
     assert "_" in neighbor._latex("dummy")
-    assert f.dtype == DynamicType.NUMERIC_TYPE
+    assert f.dtype == create_type("float64")
 
     f = Field.create_fixed_size("f", (8, 8, 2, 2, 2), index_dimensions=3)
     assert f.center_vector == sp.Array(
         [[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)]
     )
-    assert f.dtype == DynamicType.NUMERIC_TYPE
+    assert f.dtype == create_type("float64")
 
     f = Field.create_generic("f", spatial_dimensions=5, index_dimensions=2)
     field_access = f[1, -1, 2, -3, 0](1, 0)
diff --git a/tests/runtime/test_datahandling.py b/tests/runtime/test_datahandling.py
index 29e639c88..c73ec829d 100644
--- a/tests/runtime/test_datahandling.py
+++ b/tests/runtime/test_datahandling.py
@@ -249,7 +249,7 @@ def test_add_arrays():
     dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=0, default_layout='numpy')
     x_, y_ = dh.add_arrays(field_description)
 
-    x, y = ps.fields(field_description + ': [3,4,5]')
+    x, y = ps.fields(field_description + ': float64[3,4,5]')
 
     assert x_ == x
     assert y_ == y
-- 
GitLab