diff --git a/docs/source/api/symbolic/sympyextensions.rst b/docs/source/api/symbolic/sympyextensions.rst
index e3d10fbdf67a1fc26fe1e339b0e642d86f1be51e..4190569d2ab838356a4f501e6811bb7e9202e666 100644
--- a/docs/source/api/symbolic/sympyextensions.rst
+++ b/docs/source/api/symbolic/sympyextensions.rst
@@ -71,7 +71,10 @@ Typed Expressions
 .. autoclass:: pystencils.DynamicType
     :members:
 
-.. autoclass:: pystencils.sympyextensions.CastFunc
+.. autoclass:: pystencils.sympyextensions.typed_sympy.TypeCast
+    :members:
+
+.. autoclass:: pystencils.sympyextensions.tcast
 
 
 Integer Operations
diff --git a/docs/source/contributing/dev-workflow.md b/docs/source/contributing/dev-workflow.md
index 2aee09ba2e78bd0041ec6d2d2860385514240ecf..fe8b70e7703385d45f7fd2d53822424b193c2592 100644
--- a/docs/source/contributing/dev-workflow.md
+++ b/docs/source/contributing/dev-workflow.md
@@ -118,10 +118,10 @@ mypy src/pystencils
 ::::
 
 :::{note}
-Type checking is currently restricted to the `codegen`, `jit`, `backend`, and `types` modules,
-since most code in the remaining modules is significantly older and is not comprehensively
-type-annotated. As more modules are updated with type annotations, this list will expand in the future.
-If you think a new module is ready to be type-checked, add an exception clause for it in the `mypy.ini` file.
+Type checking is currently restricted to only a few modules, which are listed in the `mypy.ini` file.
+Most code in the remaining modules is significantly older and is not comprehensively type-annotated.
+As more modules are updated with type annotations, this list will expand in the future.
+If you think a new module is ready to be type-checked, add an exception clause for it to `mypy.ini`.
 :::
 
 ## Running the Test Suite
diff --git a/docs/source/index.rst b/docs/source/index.rst
index cb455c8b4d1589353a7538c0e98b5eab864b4392..6dba50af184ee95c7378a2e923bd76a6d97883a2 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -82,6 +82,7 @@ Topics
   user_manual/symbolic_language
   user_manual/kernelcreation
   user_manual/gpu_kernels
+  user_manual/WorkingWithTypes
 
 .. toctree::
   :maxdepth: 1
diff --git a/docs/source/user_manual/WorkingWithTypes.md b/docs/source/user_manual/WorkingWithTypes.md
new file mode 100644
index 0000000000000000000000000000000000000000..e0f9283773cea55453ecdfe2377dc82b096e0741
--- /dev/null
+++ b/docs/source/user_manual/WorkingWithTypes.md
@@ -0,0 +1,164 @@
+---
+file_format: mystnb
+kernelspec:
+  display_name: Python 3 (ipykernel)
+  language: python
+  name: python3
+mystnb:
+  execution_mode: cache
+---
+
+# Working with Data Types
+
+This guide demonstrates the various options for customizing the data types
+in generated kernels.
+Data types can be modified at different levels of granularity:
+individual fields and symbols,
+single subexpressions,
+or the entire kernel.
+
+```{code-cell} ipython3
+:tags: [remove-cell]
+import pystencils as ps
+import sympy as sp
+```
+
+## Changing the Default Data Types
+
+The pystencils code generator defines two default data types:
+ - The default *numeric type*, which is applied to all numerical computations that are not
+   otherwise explicitly typed; it defaults to `float64`.
+ - The default *index type*, which is used for all loop and field index calculations; it defaults to `int64`.
+
+These can be modified by setting the
+{any}`default_dtype <CreateKernelConfig.default_dtype>` and
+{any}`index_dtype <CreateKernelConfig.index_dtype>`
+options of the code generator configuration:
+
+```{code-cell} ipython3
+cfg = ps.CreateKernelConfig()
+cfg.default_dtype = "float32"
+cfg.index_dtype = "int32"
+```
+
+Modifying these will change the way types for [untyped symbols](#untyped-symbols)
+and [dynamically typed expressions](#dynamic-typing) are computed.
+
+## Setting the Types of Fields and Symbols
+
+(untyped-symbols)=
+### Untyped Symbols
+
+Symbols used inside a kernel are most commonly created using
+{any}`sp.symbols <sympy.core.symbol.symbols>` or
+{any}`sp.Symbol <sympy.core.symbol.Symbol>`.
+These symbols are *untyped*; they will receive a type during code generation
+according to these rules:
+ - Free untyped symbols (i.e. symbols not defined by an assignment inside the kernel) receive the 
+   {any}`default data type <CreateKernelConfig.default_dtype>` specified in the code generator configuration.
+ - Bound untyped symbols (i.e. symbols that *are* defined in an assignment)
+   receive the data type that was computed for the right-hand side expression of their defining assignment.
+
+If you are working on kernels with homogeneous data types, using untyped symbols will usually be sufficient.
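+
+The following cell sketches these rules on a minimal example (the symbol and field names
+are chosen purely for illustration): `a` is a free untyped symbol and will receive the
+default numeric type, while `tmp` is bound by its defining assignment and therefore
+inherits the type computed for `a * f[0, 0]`.
+
+```{code-cell} ipython3
+f, g = ps.fields("f, g: float32[2D]")
+a, tmp = sp.symbols("a, tmp")
+
+assignments = [
+    ps.Assignment(tmp, a * f[0, 0]),  # tmp is bound here and inherits float32
+    ps.Assignment(g[0, 0], tmp),
+]
+
+kernel = ps.create_kernel(assignments, cfg)
+```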
+
+### Explicitly Typed Symbols and Fields
+
+If you need more control over the data types in (parts of) your kernel,
+you will have to explicitly specify them.
+To set an explicit data type for a symbol, use the {any}`TypedSymbol` class of pystencils:
+
+```{code-cell} ipython3
+x_typed = ps.TypedSymbol("x", "uint32")
+x_typed, str(x_typed.dtype)
+```
+
+You can set a `TypedSymbol` to any data type provided by [the type system](#page_type_system),
+which will then be enforced by the code generator.
+
+The same holds for fields:
+When creating fields through the {any}`fields <pystencils.field.fields>` function,
+add the type to the descriptor string; for instance:
+
+```{code-cell} ipython3
+f, g = ps.fields("f(1), g(3): float32[3D]")
+str(f.dtype), str(g.dtype)
+```
+
+When using `Field.create_generic` or `Field.create_fixed_size`, on the other hand,
+you can set the data type via the `dtype` keyword argument.
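+
+As a small sketch (the field name `h` is made up for illustration), an explicit entry type
+can be passed to `Field.create_generic` like this:
+
+```{code-cell} ipython3
+h = ps.Field.create_generic("h", spatial_dimensions=2, dtype="float16")
+str(h.dtype)
+```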
+
+(dynamic-typing)=
+### Dynamically Typed Symbols and Fields
+
+Apart from explicitly setting data types,
+`TypedSymbol`s and fields can also receive a *dynamic data type* (see {any}`DynamicType`).
+There are two options:
+ - Symbols or fields annotated with {any}`DynamicType.NUMERIC_TYPE` will always receive
+   the {any}`default numeric type <CreateKernelConfig.default_dtype>` configured for the
+   code generator.
+   This is the default setting for fields
+   created through `fields`, `Field.create_generic` or `Field.create_fixed_size`.
+ - When annotated with {any}`DynamicType.INDEX_TYPE`, on the other hand, they will receive
+   the {any}`index data type <CreateKernelConfig.index_dtype>` configured for the kernel.
+
+Using dynamic typing, you can ensure that symbols receive either the default numeric or
+index type without stating it explicitly, making your kernel definition
+independent of the code generator configuration.
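+
+A brief example (the names are again purely illustrative): the symbol `n` below always receives
+the configured index type, and a field created without an explicit type defaults to the
+dynamic numeric type:
+
+```{code-cell} ipython3
+n = ps.TypedSymbol("n", ps.DynamicType.INDEX_TYPE)
+c = ps.fields("c(1): [2D]")
+str(n.dtype), str(c.dtype)
+```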
+
+## Mixing Types Inside Expressions
+
+Pystencils enforces that all symbols, constants, and fields occurring inside an expression
+have the same data type.
+The code generator never introduces implicit casts; if any type conflicts arise, it terminates with an error.
+
+Still, there are cases where you want to combine subexpressions of different types;
+maybe you need to compute geometric information from loop counters or other integers,
+or you are doing mixed-precision numerical computations.
+In these cases, you might have to introduce explicit type casts when values move from one type context to another.
+ 
+ <!-- 2. Annotate expressions with a specific data type to ensure computations are performed in that type. 
+  TODO: See #97 (https://i10git.cs.fau.de/pycodegen/pystencils/-/issues/97)
+ -->
+
+(type_casts)=
+### Type Casts
+
+Type casts can be introduced into kernels using the {any}`tcast` symbolic function.
+It takes an expression and a data type, which is either an explicit type (see [the type system](#page_type_system))
+or a dynamic type ({any}`DynamicType`):
+
+```{code-cell} ipython3
+x, y = sp.symbols("x, y")
+expr1 = ps.tcast(x, "float32")
+expr2 = ps.tcast(3 + y, ps.DynamicType.INDEX_TYPE)
+
+str(expr1.dtype), str(expr2.dtype)
+```
+
+When a type cast occurs, pystencils will compute the type of its argument independently
+and then introduce a runtime cast to the target type.
+That target type must be compatible with the type computed for the surrounding expression
+in which the cast is embedded.
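+
+As a minimal sketch (field names are made up for illustration), here an integer-typed field
+is cast to `float32` before entering a floating-point computation, so that the surrounding
+expression type-checks without conflicts:
+
+```{code-cell} ipython3
+idx_data = ps.fields("idx_data: int32[2D]")
+out = ps.fields("out: float32[2D]")
+
+asm = ps.Assignment(out[0, 0], sp.sin(ps.tcast(idx_data[0, 0], "float32")))
+kernel = ps.create_kernel([asm], cfg)
+```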
+
+## Understanding the pystencils Type Inference System
+
+To correctly apply varying data types to pystencils kernels, it is important to understand
+how pystencils computes and propagates the data types of symbols and expressions.
+
+Type inference happens on the level of assignments.
+For each assignment $x := \mathrm{calc}(y_1, \dots, y_n)$,
+the system first attempts to compute a *unique* type for the right-hand side (RHS) $\mathrm{calc}(y_1, \dots, y_n)$.
+It searches the RHS for subexpressions whose types are already known;
+these might be typed symbols
+(whose types are either set explicitly by the user,
+or have been determined from their defining assignment),
+field accesses,
+or explicitly typed expressions.
+It then attempts to apply that data type to the entire expression.
+If type conflicts occur, the process fails and the code generator raises an error.
+Otherwise, the resulting type is assigned to the left-hand side symbol $x$.
+
+:::{admonition} Developer's To Do
+It would be great to illustrate this using a GraphViz plot of an AST,
+with nodes colored according to their data types.
+:::
diff --git a/docs/source/user_manual/kernelcreation.md b/docs/source/user_manual/kernelcreation.md
index c85c8f99d3490602321c57f881b32b0127051c70..ad346473cc05dc746794ffc8f56f5bf21ffdec90 100644
--- a/docs/source/user_manual/kernelcreation.md
+++ b/docs/source/user_manual/kernelcreation.md
@@ -138,7 +138,7 @@ This happens roughly according to the following rules:
 We can observe this behavior by setting up a kernel including several fields with different data types:
 
 ```{code-cell} ipython3
-from pystencils.sympyextensions import CastFunc
+from pystencils.sympyextensions import tcast
 
 f = ps.fields("f: float32[2D]")
 g = ps.fields("g: float16[2D]")
diff --git a/mypy.ini b/mypy.ini
index cc23a503a2da6c9849d3a41e82fe8ceb8de13b43..c8a7195e2e28bffbeb79e1e552822cea4e8dd041 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -17,6 +17,12 @@ ignore_errors = False
 [mypy-pystencils.jit.*]
 ignore_errors = False
 
+[mypy-pystencils.field]
+ignore_errors = False
+
+[mypy-pystencils.sympyextensions.typed_sympy]
+ignore_errors = False
+
 [mypy-setuptools.*]
 ignore_missing_imports=true
 
diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py
index a23ce185d1a4f6c9cd9a17fccf315462eddf287f..07283d5294bc08c8e68e6a40af0f956b36a0129a 100644
--- a/src/pystencils/__init__.py
+++ b/src/pystencils/__init__.py
@@ -32,7 +32,7 @@ from .spatial_coordinates import (
 from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil
 from .simp import AssignmentCollection
 from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
-from .sympyextensions import SymbolCreator
+from .sympyextensions import SymbolCreator, tcast
 from .datahandling import create_data_handling
 
 __all__ = [
@@ -77,6 +77,7 @@ __all__ = [
     "x_staggered_vector",
     "fd",
     "stencil",
+    "tcast",
 ]
 
 from . import _version
diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index 8f5931c6494bbf1eca950e38df0053e79af3e81b..68da893ff6204c73d61dcebddb1da602f37520c7 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -93,6 +93,16 @@ class KernelCreationContext:
     def index_dtype(self) -> PsIntegerType:
         """Data type used by default for index expressions"""
         return self._index_dtype
+    
+    def resolve_dynamic_type(self, dtype: DynamicType | PsType) -> PsType:
+        """Selects the appropriate data type for `DynamicType` instances, and returns all other types as they are."""
+        match dtype:
+            case DynamicType.NUMERIC_TYPE:
+                return self._default_dtype
+            case DynamicType.INDEX_TYPE:
+                return self._index_dtype
+            case _:
+                return dtype
 
     @property
     def metadata(self) -> dict[str, Any]:
@@ -339,6 +349,8 @@ class KernelCreationContext:
             if isinstance(s, TypedSymbol)
         )
 
+        entry_type = self.resolve_dynamic_type(field.dtype)
+
         if len(idx_types) > 1:
             raise KernelConstraintsError(
                 f"Multiple incompatible types found in index symbols of field {field}: "
@@ -375,10 +387,10 @@ class KernelCreationContext:
 
         base_ptr = self.get_symbol(
             DEFAULTS.field_pointer_name(field.name),
-            PsPointerType(field.dtype, restrict=True),
+            PsPointerType(entry_type, restrict=True),
         )
 
-        return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides)
+        return PsBuffer(field.name, entry_type, base_ptr, buf_shape, buf_strides)
 
     def _create_buffer_field_buffer(self, field: Field) -> PsBuffer:
         if field.spatial_dimensions != 1:
@@ -418,10 +430,11 @@ class KernelCreationContext:
             ]
 
         buf_strides = [PsConstant(num_entries, idx_type), PsConstant(1, idx_type)]
+        buf_dtype = self.resolve_dynamic_type(field.dtype)
 
         base_ptr = self.get_symbol(
             DEFAULTS.field_pointer_name(field.name),
-            PsPointerType(field.dtype, restrict=True),
+            PsPointerType(buf_dtype, restrict=True),
         )
 
-        return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides)
+        return PsBuffer(field.name, buf_dtype, base_ptr, buf_shape, buf_strides)
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index c38fcbc9730e046b2d51762c5846c1542f0cbe74..4fd09f879dd8d98903753c8709543e0bcc3fd3e1 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -13,7 +13,7 @@ from ...sympyextensions import (
     integer_functions,
     ConditionalFieldAccess,
 )
-from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType
+from ...sympyextensions.typed_sympy import TypedSymbol, TypeCast, DynamicType
 from ...sympyextensions.pointers import AddressOf, mem_acc
 from ...field import Field, FieldType
 
@@ -270,14 +270,7 @@ class FreezeExpressions:
         return num / denom
 
     def map_TypedSymbol(self, expr: TypedSymbol):
-        dtype = expr.dtype
-
-        match dtype:
-            case DynamicType.NUMERIC_TYPE:
-                dtype = self._ctx.default_dtype
-            case DynamicType.INDEX_TYPE:
-                dtype = self._ctx.index_dtype
-
+        dtype = self._ctx.resolve_dynamic_type(expr.dtype)
         symb = self._ctx.get_symbol(expr.name, dtype)
         return PsSymbolExpr(symb)
 
@@ -490,7 +483,7 @@ class FreezeExpressions:
             ]
         return cast(PsCall, args[0])
 
-    def map_CastFunc(self, cast_expr: CastFunc) -> PsCast | PsConstantExpr:
+    def map_TypeCast(self, cast_expr: TypeCast) -> PsCast | PsConstantExpr:
         dtype: PsType
         match cast_expr.dtype:
             case DynamicType.NUMERIC_TYPE:
diff --git a/src/pystencils/field.py b/src/pystencils/field.py
index 1a3a13b73125bd07e6161a1694a6bf03dc2ba506..246232efde7a6b432598f614492725e2ea063cff 100644
--- a/src/pystencils/field.py
+++ b/src/pystencils/field.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import functools
 import hashlib
 import operator
@@ -5,23 +7,28 @@ import pickle
 import re
 from enum import Enum
 from itertools import chain
-from typing import List, Optional, Sequence, Set, Tuple, Union
+from typing import List, Optional, Sequence, Set, Tuple
+from warnings import warn
 
 import numpy as np
 import sympy as sp
 from sympy.core.cache import cacheit
 
 from .defaults import DEFAULTS
-from pystencils.alignedarray import aligned_empty
-from pystencils.spatial_coordinates import x_staggered_vector, x_vector
-from pystencils.stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string
-from pystencils.types import PsType, PsStructType, create_type
-from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType
-from pystencils.sympyextensions import is_integer_sequence
-from pystencils.types import UserTypeSpec
+from .alignedarray import aligned_empty
+from .spatial_coordinates import x_staggered_vector, x_vector
+from .stencil import (
+    direction_string_to_offset,
+    inverse_direction,
+    offset_to_direction_string,
+)
+from .types import PsType, PsStructType, create_type
+from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
+from .sympyextensions import is_integer_sequence
+from .types import UserTypeSpec
 
 
-__all__ = ['Field', 'fields', 'FieldType', 'Field']
+__all__ = ["Field", "fields", "FieldType"]
 
 
 class FieldType(Enum):
@@ -63,7 +70,10 @@ class FieldType(Enum):
     @staticmethod
     def is_staggered(field):
         assert isinstance(field, Field)
-        return field.field_type == FieldType.STAGGERED or field.field_type == FieldType.STAGGERED_FLUX
+        return (
+            field.field_type == FieldType.STAGGERED
+            or field.field_type == FieldType.STAGGERED_FLUX
+        )
 
     @staticmethod
     def is_staggered_flux(field):
@@ -123,23 +133,34 @@ class Field:
         >>> assignments = [Assignment(dst[0,0](i), src[-offset](i)) for i, offset in enumerate(stencil)];
 
     Args:
-        field_name: something
-        field_type: something
-        dtype: something
-        layout: something
-        shape: something
-        strides: something
+        field_name: The field's name
+        field_type: The kind of the field
+        dtype: Data type of the field's entries
+        layout: Linearization order of the field's spatial dimensions
+        shape: Total shape (spatial and index) of the field
+        strides: Linearization strides of the field
     """
 
     @staticmethod
-    def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec = np.float64, index_dimensions=0, 
-                       layout='numpy', index_shape=None, field_type=FieldType.GENERIC) -> 'Field':
+    def create_generic(
+        field_name,
+        spatial_dimensions,
+        dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE,
+        index_dimensions=0,
+        layout="numpy",
+        index_shape=None,
+        field_type=FieldType.GENERIC,
+    ) -> "Field":
         """
-        Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
+        Creates a generic field whose size is not fixed, i.e. it can be called with arrays of different sizes.
+
+        **Field Element Type:** By default, the data type of the field entries is left undetermined until
+        code generation, at which point it is set to the default numerical type of the kernel.
+        You can specify a concrete type using the `dtype` parameter.
 
         Args:
             field_name: symbolic name for the field
-            dtype: numpy data type of the array the kernel is called with later
+            dtype: Data type of the field entries
             spatial_dimensions: see documentation of Field
             index_dimensions: see documentation of Field
             layout: tuple specifying the loop ordering of the spatial dimensions e.g. (2, 1, 0 ) means that
@@ -159,41 +180,60 @@ class Field:
             layout = spatial_layout_string_to_tuple(layout, dim=spatial_dimensions)
 
         total_dimensions = spatial_dimensions + index_dimensions
+        shape: tuple[TypedSymbol | int, ...]
+
         if index_shape is None or len(index_shape) == 0:
-            shape = tuple([
-                TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE) 
-                for i in range(total_dimensions)
-            ])
+            shape = tuple(
+                [
+                    TypedSymbol(
+                        DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE
+                    )
+                    for i in range(total_dimensions)
+                ]
+            )
         else:
             shape = tuple(
                 [
-                    TypedSymbol(DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE)
+                    TypedSymbol(
+                        DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE
+                    )
                     for i in range(spatial_dimensions)
-                ] + list(index_shape)
+                ]
+                + list(index_shape)
             )
 
-        strides = tuple([
-            TypedSymbol(DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE) 
-            for i in range(total_dimensions)
-        ])
+        strides: tuple[TypedSymbol | int, ...] = tuple(
+            [
+                TypedSymbol(
+                    DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE
+                )
+                for i in range(total_dimensions)
+            ]
+        )
 
-        dtype = create_type(dtype)
-        np_data_type = dtype.numpy_dtype
-        assert np_data_type is not None
-        
-        if np_data_type.fields is not None:
+        if not isinstance(dtype, DynamicType):
+            dtype = create_type(dtype)
+
+        if isinstance(dtype, PsStructType):
             if index_dimensions != 0:
-                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
+                raise ValueError(
+                    "Structured arrays/fields are not allowed to have an index dimension"
+                )
             shape += (1,)
             strides += (1,)
+
         if field_type == FieldType.STAGGERED and index_dimensions == 0:
             raise ValueError("A staggered field needs at least one index dimension")
 
         return Field(field_name, field_type, dtype, layout, shape, strides)
 
     @staticmethod
-    def create_from_numpy_array(field_name: str, array: np.ndarray, index_dimensions: int = 0,
-                                field_type=FieldType.GENERIC) -> 'Field':
+    def create_from_numpy_array(
+        field_name: str,
+        array: np.ndarray,
+        index_dimensions: int = 0,
+        field_type=FieldType.GENERIC,
+    ) -> Field:
         """Creates a field based on the layout, data type, and shape of a given numpy array.
 
         Kernels created for these kind of fields can only be called with arrays of the same layout, shape and type.
@@ -206,7 +246,9 @@ class Field:
         """
         spatial_dimensions = len(array.shape) - index_dimensions
         if spatial_dimensions < 1:
-            raise ValueError("Too many index dimensions. At least one spatial dimension required")
+            raise ValueError(
+                "Too many index dimensions. At least one spatial dimension required"
+            )
 
         full_layout = get_layout_of_array(array)
         spatial_layout = tuple([i for i in full_layout if i < spatial_dimensions])
@@ -218,21 +260,31 @@ class Field:
         numpy_dtype = np.dtype(array.dtype)
         if numpy_dtype.fields is not None:
             if index_dimensions != 0:
-                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
+                raise ValueError(
+                    "Structured arrays/fields are not allowed to have an index dimension"
+                )
             shape += (1,)
             strides += (1,)
         if field_type == FieldType.STAGGERED and index_dimensions == 0:
             raise ValueError("A staggered field needs at least one index dimension")
 
-        return Field(field_name, field_type, array.dtype, spatial_layout, shape, strides)
+        return Field(
+            field_name, field_type, array.dtype, spatial_layout, shape, strides
+        )
 
     @staticmethod
-    def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0,
-                          dtype: UserTypeSpec = np.float64, layout: str = 'numpy',
-                          strides: Optional[Sequence[int]] = None,
-                          field_type=FieldType.GENERIC) -> 'Field':
+    def create_fixed_size(
+        field_name: str,
+        shape: tuple[int, ...],
+        index_dimensions: int = 0,
+        dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE,
+        layout: str | tuple[int, ...] = "numpy",
+        memory_strides: None | Sequence[int] = None,
+        strides: Optional[Sequence[int]] = None,
+        field_type=FieldType.GENERIC,
+    ) -> Field:
         """
-        Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout
+        Creates a field with fixed sizes, i.e. it can be called only with arrays of the same size and layout.
 
         Args:
             field_name: symbolic name for the field
@@ -240,54 +292,90 @@ class Field:
             index_dimensions: how many of the trailing dimensions are interpreted as index (as opposed to spatial)
             dtype: numpy data type of the array the kernel is called with later
             layout: full layout of array, not only spatial dimensions
-            strides: strides in bytes or None to automatically compute them from shape (assuming no padding)
+            memory_strides: Linearization strides for each dimension;
+                i.e. the number of elements to skip to get from one index to the next in the respective dimension.
             field_type: kind of field
         """
+        if strides is not None:
+            warn(
+                "The `strides` parameter to `Field.create_fixed_size` is deprecated "
+                "and will be removed in pystencils 2.1. "
+                "Use `memory_strides` instead; "
+                "beware that `memory_strides` takes the number of *elements* to skip, "
+                "instead of the number of bytes.",
+                FutureWarning
+            )
+
+            if memory_strides is not None:
+                raise ValueError("Cannot specify `memory_strides` and deprecated parameter `strides` at the same time.")
+            
+            if isinstance(dtype, DynamicType):
+                raise ValueError("Cannot specify the deprecated parameter `strides` together with a `DynamicType`. "
+                                 "Set `memory_strides` instead.")
+            
+            np_type = create_type(dtype).numpy_dtype
+            assert np_type is not None
+            memory_strides = tuple([s // np_type.itemsize for s in strides])
+
         spatial_dimensions = len(shape) - index_dimensions
         assert spatial_dimensions >= 1
 
         if isinstance(layout, str):
-            layout = layout_string_to_tuple(layout, spatial_dimensions + index_dimensions)
+            layout = layout_string_to_tuple(
+                layout, spatial_dimensions + index_dimensions
+            )
+
+        if not isinstance(dtype, DynamicType):
+            dtype = create_type(dtype)
+
+        shape_tuple = tuple(int(s) for s in shape)
+        strides_tuple: tuple[int, ...]
 
-        shape = tuple(int(s) for s in shape)
-        if strides is None:
+        if memory_strides is None:
-            strides = compute_strides(shape, layout)
+            strides_tuple = compute_strides(shape_tuple, layout)
         else:
-            assert len(strides) == len(shape)
-            strides = tuple([s // np.dtype(dtype).itemsize for s in strides])
+            assert len(memory_strides) == len(shape_tuple)
+            strides_tuple = tuple(memory_strides)
 
-        dtype = create_type(dtype)
-        numpy_dtype = dtype.numpy_dtype
-        assert numpy_dtype is not None
-
-        if numpy_dtype.fields is not None:
+        if isinstance(dtype, PsStructType):
             if index_dimensions != 0:
-                raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
-            shape += (1,)
-            strides += (1,)
+                raise ValueError(
+                    "Structured arrays/fields are not allowed to have an index dimension"
+                )
+            shape_tuple += (1,)
+            strides_tuple += (1,)
         if field_type == FieldType.STAGGERED and index_dimensions == 0:
             raise ValueError("A staggered field needs at least one index dimension")
 
         spatial_layout = list(layout)
         for i in range(spatial_dimensions, len(layout)):
             spatial_layout.remove(i)
-        return Field(field_name, field_type, dtype, tuple(spatial_layout), shape, strides)
+        return Field(
+            field_name,
+            field_type,
+            dtype,
+            tuple(spatial_layout),
+            shape_tuple,
+            strides_tuple,
+        )
 
     def __init__(
         self,
         field_name: str,
         field_type: FieldType,
-        dtype: UserTypeSpec,
+        dtype: UserTypeSpec | DynamicType,
         layout: tuple[int, ...],
         shape,
-        strides
+        strides,
     ):
         """Do not use directly. Use static create* methods"""
         self._field_name = field_name
         assert isinstance(field_type, FieldType)
         assert len(shape) == len(strides)
         self.field_type = field_type
-        self._dtype = create_type(dtype)
+        self._dtype: PsType | DynamicType = (
+            create_type(dtype) if not isinstance(dtype, DynamicType) else dtype
+        )
         self._layout = normalize_layout(layout)
         self.shape = shape
         self.strides = strides
@@ -299,9 +387,23 @@ class Field:
 
     def new_field_with_different_name(self, new_name):
         if self.has_fixed_shape:
-            return Field(new_name, self.field_type, self._dtype, self._layout, self.shape, self.strides)
+            return Field(
+                new_name,
+                self.field_type,
+                self._dtype,
+                self._layout,
+                self.shape,
+                self.strides,
+            )
         else:
-            return Field(new_name, self.field_type, self.dtype, self.layout, self.shape, self.strides)
+            return Field(
+                new_name,
+                self.field_type,
+                self.dtype,
+                self.layout,
+                self.shape,
+                self.strides,
+            )
 
     @property
     def spatial_dimensions(self) -> int:
@@ -328,7 +430,7 @@ class Field:
 
     @property
     def spatial_shape(self) -> Tuple[int, ...]:
-        return self.shape[:self.spatial_dimensions]
+        return self.shape[: self.spatial_dimensions]
 
     @property
     def has_fixed_shape(self):
@@ -344,31 +446,34 @@ class Field:
 
     @property
     def spatial_strides(self):
-        return self.strides[:self.spatial_dimensions]
+        return self.strides[: self.spatial_dimensions]
 
     @property
     def index_strides(self):
         return self.strides[self.spatial_dimensions:]
 
     @property
-    def dtype(self) -> PsType:
+    def dtype(self) -> PsType | DynamicType:
         return self._dtype
 
     @property
-    def itemsize(self):
-        return self.dtype.itemsize
+    def itemsize(self) -> int | None:
+        if isinstance(self.dtype, PsType):
+            return self.dtype.itemsize
+        else:
+            return None
 
     def __repr__(self):
         if any(isinstance(s, sp.Symbol) for s in self.spatial_shape):
-            spatial_shape_str = f'{self.spatial_dimensions}d'
+            spatial_shape_str = f"{self.spatial_dimensions}d"
         else:
-            spatial_shape_str = ','.join(str(i) for i in self.spatial_shape)
-        index_shape_str = ','.join(str(i) for i in self.index_shape)
+            spatial_shape_str = ",".join(str(i) for i in self.spatial_shape)
+        index_shape_str = ",".join(str(i) for i in self.index_shape)
 
         if self.index_shape:
-            return f'{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]'
+            return f"{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]"
         else:
-            return f'{self._field_name}: {self.dtype}[{spatial_shape_str}]'
+            return f"{self._field_name}: {self.dtype}[{spatial_shape_str}]"
 
     def __str__(self):
         return self.name
@@ -389,12 +494,26 @@ class Field:
         elif len(index_shape) == 1:
             return sp.Matrix([self(i) for i in range(index_shape[0])])
         elif len(index_shape) == 2:
-            return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])])
+            return sp.Matrix(
+                [
+                    [self(i, j) for j in range(index_shape[1])]
+                    for i in range(index_shape[0])
+                ]
+            )
         elif len(index_shape) == 3:
-            return sp.Array([[[self(i, j, k) for k in range(index_shape[2])]
-                              for j in range(index_shape[1])] for i in range(index_shape[0])])
+            return sp.Array(
+                [
+                    [
+                        [self(i, j, k) for k in range(index_shape[2])]
+                        for j in range(index_shape[1])
+                    ]
+                    for i in range(index_shape[0])
+                ]
+            )
         else:
-            raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions")
+            raise NotImplementedError(
+                "center_vector is not implemented for more than 3 index dimensions"
+            )
 
     @property
     def center(self):
@@ -410,12 +529,20 @@ class Field:
         if self.index_dimensions == 0:
             return sp.Matrix([self.__getitem__(offset)])
         elif self.index_dimensions == 1:
-            return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])])
+            return sp.Matrix(
+                [self.__getitem__(offset)(i) for i in range(self.index_shape[0])]
+            )
         elif self.index_dimensions == 2:
-            return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])]
-                              for i in range(self.index_shape[0])])
+            return sp.Matrix(
+                [
+                    [self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])]
+                    for i in range(self.index_shape[0])
+                ]
+            )
         else:
-            raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions")
+            raise NotImplementedError(
+                "neighbor_vector is not implemented for more than 2 index dimensions"
+            )
 
     def __getitem__(self, offset):
         if type(offset) is np.ndarray:
@@ -425,7 +552,9 @@ class Field:
         if type(offset) is not tuple:
             offset = (offset,)
         if len(offset) != self.spatial_dimensions:
-            raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}")
+            raise ValueError(
+                f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}"
+            )
         return Field.Access(self, offset)
 
     def absolute_access(self, offset, index):
@@ -448,7 +577,9 @@ class Field:
             offset = tuple(direction_string_to_offset(offset, self.spatial_dimensions))
             offset = tuple([o * sp.Rational(1, 2) for o in offset])
         if len(offset) != self.spatial_dimensions:
-            raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}")
+            raise ValueError(
+                f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}"
+            )
 
         prefactor = 1
         neighbor_vec = [0] * len(offset)
@@ -462,25 +593,33 @@ class Field:
             if FieldType.is_staggered_flux(self):
                 prefactor = -1
         if neighbor not in self.staggered_stencil:
-            raise ValueError(f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil")
+            raise ValueError(
+                f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil"
+            )
 
         offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec))
 
         idx = self.staggered_stencil.index(neighbor)
 
-        if self.index_dimensions == 1:  # this field stores a scalar value at each staggered position
+        if (
+            self.index_dimensions == 1
+        ):  # this field stores a scalar value at each staggered position
             if index is not None:
                 raise ValueError("Cannot specify an index for a scalar staggered field")
             return prefactor * Field.Access(self, offset, (idx,))
         else:  # this field stores a vector or tensor at each staggered position
             if index is None:
-                raise ValueError(f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}")
+                raise ValueError(
+                    f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}"
+                )
             if type(index) is np.ndarray:
                 index = tuple(index)
             if type(index) is not tuple:
                 index = (index,)
             if self.index_dimensions != len(index) + 1:
-                raise ValueError(f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}")
+                raise ValueError(
+                    f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}"
+                )
 
             return prefactor * Field.Access(self, offset, (idx, *index))
 
@@ -491,30 +630,54 @@ class Field:
         if self.index_dimensions == 1:
             return sp.Matrix([self.staggered_access(offset)])
         elif self.index_dimensions == 2:
-            return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])])
+            return sp.Matrix(
+                [self.staggered_access(offset, i) for i in range(self.index_shape[1])]
+            )
         elif self.index_dimensions == 3:
-            return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])]
-                              for i in range(self.index_shape[1])])
+            return sp.Matrix(
+                [
+                    [
+                        self.staggered_access(offset, (i, k))
+                        for k in range(self.index_shape[2])
+                    ]
+                    for i in range(self.index_shape[1])
+                ]
+            )
         else:
-            raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions")
+            raise NotImplementedError(
+                "staggered_vector_access is not implemented for more than 3 index dimensions"
+            )
 
     @property
     def staggered_stencil(self):
         assert FieldType.is_staggered(self)
         stencils = {
-            2: {
-                2: ["W", "S"],  # D2Q5
-                4: ["W", "S", "SW", "NW"]  # D2Q9
-            },
+            2: {
+                2: ["W", "S"],  # D2Q5
+                4: ["W", "S", "SW", "NW"],  # D2Q9
+            },
             3: {
                 3: ["W", "S", "B"],  # D3Q7
                 7: ["W", "S", "B", "BSW", "TSW", "BNW", "TNW"],  # D3Q15
                 9: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS"],  # D3Q19
-                13: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS", "BSW", "TSW", "BNW", "TNW"]  # D3Q27
-            }
+                13: [
+                    "W",
+                    "S",
+                    "B",
+                    "SW",
+                    "NW",
+                    "BW",
+                    "TW",
+                    "BS",
+                    "TS",
+                    "BSW",
+                    "TSW",
+                    "BNW",
+                    "TNW",
+                ],  # D3Q27
+            },
         }
         if not self.index_shape[0] in stencils[self.spatial_dimensions]:
-            raise ValueError(f"No known stencil has {self.index_shape[0]} staggered points")
+            raise ValueError(
+                f"No known stencil has {self.index_shape[0]} staggered points"
+            )
         return stencils[self.spatial_dimensions][self.index_shape[0]]
 
     @property
@@ -527,13 +690,15 @@ class Field:
         return Field.Access(self, center)(*args, **kwargs)
 
     def hashable_contents(self):
-        return (self._layout,
-                self.shape,
-                self.strides,
-                self.field_type,
-                self._field_name,
-                self.latex_name,
-                self._dtype)
+        return (
+            self._layout,
+            self.shape,
+            self.strides,
+            self.field_type,
+            self._field_name,
+            self.latex_name,
+            self._dtype,
+        )
 
     def __hash__(self):
         return hash(self.hashable_contents())
@@ -545,36 +710,53 @@ class Field:
 
     @property
     def physical_coordinates(self):
-        if hasattr(self.coordinate_transform, '__call__'):
-            return self.coordinate_transform(self.coordinate_origin + x_vector(self.spatial_dimensions))
+        if hasattr(self.coordinate_transform, "__call__"):
+            return self.coordinate_transform(
+                self.coordinate_origin + x_vector(self.spatial_dimensions)
+            )
         else:
-            return self.coordinate_transform @ (self.coordinate_origin + x_vector(self.spatial_dimensions))
+            return self.coordinate_transform @ (
+                self.coordinate_origin + x_vector(self.spatial_dimensions)
+            )
 
     @property
     def physical_coordinates_staggered(self):
-        return self.coordinate_transform @ \
-            (self.coordinate_origin + x_staggered_vector(self.spatial_dimensions))
+        return self.coordinate_transform @ (
+            self.coordinate_origin + x_staggered_vector(self.spatial_dimensions)
+        )
 
     def index_to_physical(self, index_coordinates: sp.Matrix, staggered=False):
         if staggered:
-            index_coordinates = sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates
-        if hasattr(self.coordinate_transform, '__call__'):
+            index_coordinates = (
+                sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates
+            )
+        if hasattr(self.coordinate_transform, "__call__"):
             return self.coordinate_transform(self.coordinate_origin + index_coordinates)
         else:
-            return self.coordinate_transform @ (self.coordinate_origin + index_coordinates)
+            return self.coordinate_transform @ (
+                self.coordinate_origin + index_coordinates
+            )
 
     def physical_to_index(self, physical_coordinates: sp.Matrix, staggered=False):
-        if hasattr(self.coordinate_transform, '__call__'):
-            if hasattr(self.coordinate_transform, 'inv'):
-                return self.coordinate_transform.inv()(physical_coordinates) - self.coordinate_origin
+        if hasattr(self.coordinate_transform, "__call__"):
+            if hasattr(self.coordinate_transform, "inv"):
+                return (
+                    self.coordinate_transform.inv()(physical_coordinates)
+                    - self.coordinate_origin
+                )
             else:
-                idx = sp.Matrix(sp.symbols(f'index_coordinates:{self.ndim}', real=True))
+                idx = sp.Matrix(sp.symbols(f"index_coordinates:{self.ndim}", real=True))
                 rtn = sp.solve(self.index_to_physical(idx) - physical_coordinates, idx)
-                assert rtn, f'Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}'
+                assert (
+                    rtn
+                ), f"Could not find inverse of coordinate_transform: {self.index_to_physical(idx)}"
                 return rtn
 
         else:
-            rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin
+            rtn = (
+                self.coordinate_transform.inv() @ physical_coordinates
+                - self.coordinate_origin
+            )
         if staggered:
             rtn = sp.Matrix([i - 0.5 for i in rtn])
 
@@ -603,18 +785,40 @@ class Field:
             >>> central_y_component.at_index(0)  # change component
             v_C^0
         """
+
         _iterable = False  # see https://i10git.cs.fau.de/pycodegen/pystencils/-/merge_requests/166#note_10680
 
         __match_args__ = ("field", "offsets", "index")
 
+        #   for the type checker
+        _field: Field
+        _offsets: tuple[int | sp.Basic, ...]
+        _offsetName: str
+        _superscript: None | str
+        _index: tuple[int | sp.Basic, ...] | str
+        _indirect_addressing_fields: set[Field]
+        _is_absolute_access: bool
+
         def __new__(cls, name, *args, **kwargs):
             obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs)
             return obj
 
-        def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False, dtype=None):
+        def __new_stage2__(  # type: ignore
+            self,
+            field: Field,
+            offsets: tuple[int, ...] = (0, 0, 0),
+            idx: None | tuple[int, ...] | str = None,
+            is_absolute_access: bool = False,
+            dtype: PsType | None = None,
+        ):
             field_name = field.name
             offsets_and_index = (*offsets, *idx) if idx is not None else offsets
-            constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index])
+            constant_offsets = not any(
+                [
+                    isinstance(o, sp.Basic) and not o.is_Integer
+                    for o in offsets_and_index
+                ]
+            )
 
             if not idx:
                 idx = tuple([0] * field.index_dimensions)
@@ -628,31 +832,36 @@ class Field:
                 else:
                     idx_str = ",".join([str(e) for e in idx])
                     superscript = idx_str
-                if field.has_fixed_index_shape and not isinstance(field.dtype, PsStructType):
+                if field.has_fixed_index_shape and not isinstance(
+                    field.dtype, PsStructType
+                ):
                     for i, bound in zip(idx, field.index_shape):
                         if i >= bound:
                             raise ValueError("Field index out of bounds")
             else:
-                offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[:12]
+                offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[
+                    :12
+                ]
                 superscript = None
 
             symbol_name = f"{field_name}_{offset_name}"
             if superscript is not None:
                 symbol_name += "^" + superscript
 
+            obj: Field.Access
             if dtype is not None:
                 obj = super(Field.Access, self).__xnew__(self, symbol_name, dtype)
             else:
                 obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype)
 
             obj._field = field
-            obj._offsets = []
+            _offsets: list[sp.Basic | int] = []
             for o in offsets:
                 if isinstance(o, sp.Basic):
-                    obj._offsets.append(o)
+                    _offsets.append(o)
                 else:
-                    obj._offsets.append(int(o))
-            obj._offsets = tuple(sp.sympify(obj._offsets))
+                    _offsets.append(int(o))
+            obj._offsets = tuple(sp.sympify(_offsets))
             obj._offsetName = offset_name
             obj._superscript = superscript
             obj._index = idx
@@ -660,19 +869,33 @@ class Field:
             obj._indirect_addressing_fields = set()
             for e in chain(obj._offsets, obj._index):
                 if isinstance(e, sp.Basic):
-                    obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access))
+                    obj._indirect_addressing_fields.update(
+                        a.field for a in e.atoms(Field.Access)
+                    )
 
             obj._is_absolute_access = is_absolute_access
             return obj
 
         def __getnewargs__(self):
-            return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype
+            return (
+                self.field,
+                self.offsets,
+                self.index,
+                self.is_absolute_access,
+                self.dtype,
+            )
 
         def __getnewargs_ex__(self):
-            return (self.field, self.offsets, self.index, self.is_absolute_access, self.dtype), {}
+            return (
+                self.field,
+                self.offsets,
+                self.index,
+                self.is_absolute_access,
+                self.dtype,
+            ), {}
 
         # noinspection SpellCheckingInspection
-        __xnew__ = staticmethod(__new_stage2__)
+        __xnew__ = staticmethod(__new_stage2__)  # type: ignore
         # noinspection SpellCheckingInspection
         __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
 
@@ -686,22 +909,34 @@ class Field:
                 idx = ()
 
             if len(idx) != self.field.index_dimensions:
-                raise ValueError(f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}")
+                raise ValueError(
+                    f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}"
+                )
             if len(idx) == 1 and isinstance(idx[0], str):
                 struct_type = self.field.dtype
                 assert isinstance(struct_type, PsStructType)
                 dtype = struct_type.get_member(idx[0]).dtype
-                return Field.Access(self.field, self._offsets, idx,
-                                    is_absolute_access=self.is_absolute_access, dtype=dtype)
+                return Field.Access(
+                    self.field,
+                    self._offsets,
+                    idx,
+                    is_absolute_access=self.is_absolute_access,
+                    dtype=dtype,
+                )
             else:
-                return Field.Access(self.field, self._offsets, idx,
-                                    is_absolute_access=self.is_absolute_access, dtype=self.dtype)
+                return Field.Access(
+                    self.field,
+                    self._offsets,
+                    idx,
+                    is_absolute_access=self.is_absolute_access,
+                    dtype=self.dtype,
+                )
 
         def __getitem__(self, *idx):
             return self.__call__(*idx)
 
         @property
-        def field(self) -> 'Field':
+        def field(self) -> "Field":
             """Field that the Access points to"""
             return self._field
 
@@ -713,7 +948,7 @@ class Field:
         @property
         def required_ghost_layers(self) -> int:
             """Largest spatial distance that is accessed."""
-            return int(np.max(np.abs(self._offsets)))
+            return int(np.max(np.abs(self._offsets)))  # type: ignore
 
         @property
         def nr_of_coordinates(self):
@@ -735,7 +970,7 @@ class Field:
             """Value of index coordinates as tuple."""
             return self._index
 
-        def neighbor(self, coord_id: int, offset: int) -> 'Field.Access':
+        def neighbor(self, coord_id: int, offset: int) -> "Field.Access":
             """Returns a new Access with changed spatial coordinates.
 
             Args:
@@ -749,10 +984,15 @@ class Field:
             """
             offset_list = list(self.offsets)
             offset_list[coord_id] += offset
-            return Field.Access(self.field, tuple(offset_list), self.index,
-                                is_absolute_access=self.is_absolute_access, dtype=self.dtype)
+            return Field.Access(
+                self.field,
+                tuple(offset_list),
+                self.index,
+                is_absolute_access=self.is_absolute_access,
+                dtype=self.dtype,
+            )
 
-        def get_shifted(self, *shift) -> 'Field.Access':
+        def get_shifted(self, *shift) -> "Field.Access":
             """Returns a new Access with changed spatial coordinates
 
             Example:
@@ -760,13 +1000,15 @@ class Field:
                 >>> f[0,0].get_shifted(1, 1)
                 f_NE
             """
-            return Field.Access(self.field,
-                                tuple(a + b for a, b in zip(shift, self.offsets)),
-                                self.index,
-                                is_absolute_access=self.is_absolute_access,
-                                dtype=self.dtype)
+            return Field.Access(
+                self.field,
+                tuple(a + b for a, b in zip(shift, self.offsets)),
+                self.index,
+                is_absolute_access=self.is_absolute_access,
+                dtype=self.dtype,
+            )
 
-        def at_index(self, *idx_tuple) -> 'Field.Access':
+        def at_index(self, *idx_tuple) -> "Field.Access":
             """Returns new Access with changed index.
 
             Example:
@@ -774,15 +1016,22 @@ class Field:
                 >>> f(0).at_index(8)
                 f_C^8
             """
-            return Field.Access(self.field, self.offsets, idx_tuple,
-                                is_absolute_access=self.is_absolute_access, dtype=self.dtype)
+            return Field.Access(
+                self.field,
+                self.offsets,
+                idx_tuple,
+                is_absolute_access=self.is_absolute_access,
+                dtype=self.dtype,
+            )
 
         def _eval_subs(self, old, new):
-            return Field.Access(self.field,
-                                tuple(sp.sympify(a).subs(old, new) for a in self.offsets),
-                                tuple(sp.sympify(a).subs(old, new) for a in self.index),
-                                is_absolute_access=self.is_absolute_access,
-                                dtype=self.dtype)
+            return Field.Access(
+                self.field,
+                tuple(sp.sympify(a).subs(old, new) for a in self.offsets),
+                tuple(sp.sympify(a).subs(old, new) for a in self.index),
+                is_absolute_access=self.is_absolute_access,
+                dtype=self.dtype,
+            )
 
         @property
         def is_absolute_access(self) -> bool:
@@ -790,30 +1039,43 @@ class Field:
             return self._is_absolute_access
 
         @property
-        def indirect_addressing_fields(self) -> Set['Field']:
+        def indirect_addressing_fields(self) -> Set["Field"]:
             """Returns a set of fields that the access depends on.
 
-             e.g. f[index_field[1, 0]], the outer access to f depends on index_field
-             """
+            e.g. f[index_field[1, 0]], the outer access to f depends on index_field
+            """
             return self._indirect_addressing_fields
 
         def _hashable_content(self):
             super_class_contents = super(Field.Access, self)._hashable_content()
-            return (super_class_contents, self._field.hashable_contents(), *self._index,
-                    *self._offsets, self._is_absolute_access)
+            return (
+                super_class_contents,
+                self._field.hashable_contents(),
+                *self._index,
+                *self._offsets,
+                self._is_absolute_access,
+            )
 
         def _staggered_offset(self, offsets, index):
             assert FieldType.is_staggered(self._field)
             neighbor = self._field.staggered_stencil[index]
-            neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions)
-            return [(o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)]
+            neighbor = direction_string_to_offset(
+                neighbor, self._field.spatial_dimensions
+            )
+            return [
+                (o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)
+            ]
 
         def _latex(self, _):
             n = self._field.latex_name if self._field.latex_name else self._field.name
             offset_str = ",".join([sp.latex(o) for o in self.offsets])
             if FieldType.is_staggered(self._field):
-                offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
-                                       for i in range(len(self.offsets))])
+                offset_str = ",".join(
+                    [
+                        sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
+                        for i in range(len(self.offsets))
+                    ]
+                )
             if self.is_absolute_access:
                 offset_str = f"\\mathbf{offset_str}"
             elif self.field.spatial_dimensions > 1:
@@ -834,8 +1096,12 @@ class Field:
             n = self._field.latex_name if self._field.latex_name else self._field.name
             offset_str = ",".join([sp.latex(o) for o in self.offsets])
             if FieldType.is_staggered(self._field):
-                offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
-                                       for i in range(len(self.offsets))])
+                offset_str = ",".join(
+                    [
+                        sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
+                        for i in range(len(self.offsets))
+                    ]
+                )
             if self.is_absolute_access:
                 offset_str = f"[abs]{offset_str}"
 
@@ -851,12 +1117,36 @@ class Field:
                     return f"{n}[{offset_str}]"
 
 
-def fields(description=None, index_dimensions=0, layout=None,
-           field_type=FieldType.GENERIC, **kwargs) -> Union[Field, List[Field]]:
+def fields(
+    description=None,
+    index_dimensions=0,
+    layout=None,
+    field_type=FieldType.GENERIC,
+    **kwargs,
+) -> Field | list[Field]:
     """Creates pystencils fields from a string description.
 
+    The description must be a string of the form
+    ``"name(index-shape) [name(index-shape) ...]: <data-type>[<dimension-or-shape>]"``,
+    where:
+
+    - ``name`` is the name of the respective field.
+    - ``(index-shape)`` is a tuple of integers describing the shape of the tensor on each field node;
+      it can be omitted for scalar fields.
+    - ``<data-type>`` is the numerical data type of the field's entries;
+      this can be any type parseable by `create_type`,
+      as well as ``dyn`` for `DynamicType.NUMERIC_TYPE`
+      and ``dynidx`` for `DynamicType.INDEX_TYPE`.
+      If the data type is omitted, `DynamicType.NUMERIC_TYPE` is used (see the examples below).
+    - ``<dimension-or-shape>`` is either a dimensionality (e.g. ``1D``, ``2D``, ``3D``)
+      or a tuple of integers defining the spatial shape of the field.
+
     Examples:
-        Create a 2D scalar and vector field:
+        Create a 2D field with a single index dimension, using the default (dynamic) numeric type:
+            >>> f = fields("f(1): [2D]")
+            >>> str(f.dtype)
+            'DynamicType.NUMERIC_TYPE'
+
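+        Create a 3D field using the dynamic index type via ``dynidx``:
+            >>> idx = fields("idx: dynidx[3D]")
+            >>> str(idx.dtype)
+            'DynamicType.INDEX_TYPE'
+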
+        Create a 2D scalar and vector field of 64-bit float type:
             >>> s, v = fields("s, v(2): double[2D]")
             >>> assert s.spatial_dimensions == 2 and s.index_dimensions == 0
             >>> assert (v.spatial_dimensions, v.index_dimensions, v.index_shape) == (2, 1, (2,))
@@ -882,35 +1172,70 @@ def fields(description=None, index_dimensions=0, layout=None,
             >>> f = fields("pdfs(19) : float32[3D]", layout='fzyx')
             >>> f.layout
             (2, 1, 0)
+
+    Returns:
+        The single field if the description defines only one field,
+        otherwise a list of all created fields.
     """
     result = []
     if description:
         field_descriptions, dtype, shape = _parse_description(description)
-        layout = 'numpy' if layout is None else layout
+        layout = "numpy" if layout is None else layout
         for field_name, idx_shape in field_descriptions:
             if field_name in kwargs:
                 arr = kwargs[field_name]
-                idx_shape_of_arr = () if not len(idx_shape) else arr.shape[-len(idx_shape):]
+                idx_shape_of_arr = (
+                    () if not len(idx_shape) else arr.shape[-len(idx_shape):]
+                )
                 assert idx_shape_of_arr == idx_shape
-                f = Field.create_from_numpy_array(field_name, kwargs[field_name], index_dimensions=len(idx_shape),
-                                                  field_type=field_type)
+                f = Field.create_from_numpy_array(
+                    field_name,
+                    kwargs[field_name],
+                    index_dimensions=len(idx_shape),
+                    field_type=field_type,
+                )
             elif isinstance(shape, tuple):
-                f = Field.create_fixed_size(field_name, shape + idx_shape, dtype=dtype,
-                                            index_dimensions=len(idx_shape), layout=layout, field_type=field_type)
+                f = Field.create_fixed_size(
+                    field_name,
+                    shape + idx_shape,
+                    dtype=dtype,
+                    index_dimensions=len(idx_shape),
+                    layout=layout,
+                    field_type=field_type,
+                )
             elif isinstance(shape, int):
-                f = Field.create_generic(field_name, spatial_dimensions=shape, dtype=dtype,
-                                         index_shape=idx_shape, layout=layout, field_type=field_type)
+                f = Field.create_generic(
+                    field_name,
+                    spatial_dimensions=shape,
+                    dtype=dtype,
+                    index_shape=idx_shape,
+                    layout=layout,
+                    field_type=field_type,
+                )
             elif shape is None:
-                f = Field.create_generic(field_name, spatial_dimensions=2, dtype=dtype,
-                                         index_shape=idx_shape, layout=layout, field_type=field_type)
+                f = Field.create_generic(
+                    field_name,
+                    spatial_dimensions=2,
+                    dtype=dtype,
+                    index_shape=idx_shape,
+                    layout=layout,
+                    field_type=field_type,
+                )
             else:
                 assert False
             result.append(f)
     else:
-        assert layout is None, "Layout can not be specified when creating Field from numpy array"
+        assert (
+            layout is None
+        ), "Layout can not be specified when creating Field from numpy array"
         for field_name, arr in kwargs.items():
-            result.append(Field.create_from_numpy_array(field_name, arr, index_dimensions=index_dimensions,
-                                                        field_type=field_type))
+            result.append(
+                Field.create_from_numpy_array(
+                    field_name,
+                    arr,
+                    index_dimensions=index_dimensions,
+                    field_type=field_type,
+                )
+            )
 
     if len(result) == 0:
         raise ValueError("Could not parse field description")
@@ -920,16 +1245,27 @@ def fields(description=None, index_dimensions=0, layout=None,
         return result
 
 
-def get_layout_from_strides(strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None):
+def get_layout_from_strides(
+    strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None
+):
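+    """Determine the memory layout (linearization order) of an array from its strides.
+
+    Strides of the given index dimensions are ignored; the remaining dimensions
+    are listed from slowest to fastest varying and normalized via `normalize_layout`.
+    """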
     index_dimension_ids = [] if index_dimension_ids is None else index_dimension_ids
     coordinates = list(range(len(strides)))
-    relevant_strides = [stride for i, stride in enumerate(strides) if i not in index_dimension_ids]
-    result = [x for (y, x) in sorted(zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True)]
+    relevant_strides = [
+        stride for i, stride in enumerate(strides) if i not in index_dimension_ids
+    ]
+    result = [
+        x
+        for (y, x) in sorted(
+            zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True
+        )
+    ]
     return normalize_layout(result)
 
 
-def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None):
-    """ Returns a list indicating the memory layout (linearization order) of the numpy array.
+def get_layout_of_array(
+    arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None
+):
+    """Returns a list indicating the memory layout (linearization order) of the numpy array.
 
     Examples:
         >>> get_layout_of_array(np.zeros([3,3,3]))
@@ -946,7 +1282,9 @@ def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int]
     return get_layout_from_strides(arr.strides, index_dimension_ids)
 
 
-def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0, **kwargs):
+def create_numpy_array_with_layout(
+    shape, layout, alignment=False, byte_offset=0, **kwargs
+):
     """Creates numpy array with given memory layout.
 
     Args:
@@ -970,7 +1308,10 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
         if cur_layout[i] != layout[i]:
             index_to_swap_with = cur_layout.index(layout[i])
             swaps.append((i, index_to_swap_with))
-            cur_layout[i], cur_layout[index_to_swap_with] = cur_layout[index_to_swap_with], cur_layout[i]
+            cur_layout[i], cur_layout[index_to_swap_with] = (
+                cur_layout[index_to_swap_with],
+                cur_layout[i],
+            )
     assert tuple(cur_layout) == tuple(layout)
 
     shape = list(shape)
@@ -978,7 +1319,7 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
         shape[a], shape[b] = shape[b], shape[a]
 
     if not alignment:
-        res = np.empty(shape, order='c', **kwargs)
+        res = np.empty(shape, order="c", **kwargs)
     else:
         res = aligned_empty(shape, alignment, byte_offset=byte_offset, **kwargs)
 
@@ -990,37 +1331,43 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
 def spatial_layout_string_to_tuple(layout_str: str, dim: int) -> Tuple[int, ...]:
     if dim <= 0:
         raise ValueError("Dimensionality must be positive")
-    
+
     layout_str = layout_str.lower()
 
-    if layout_str in ('fzyx', 'zyxf', 'soa', 'aos'):
+    if layout_str in ("fzyx", "zyxf", "soa", "aos"):
         if dim > 3:
-            raise ValueError(f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3.")
+            raise ValueError(
+                f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3."
+            )
         return tuple(reversed(range(dim)))
-    
-    if layout_str in ('f', 'reverse_numpy'):
+
+    if layout_str in ("f", "reverse_numpy"):
         return tuple(reversed(range(dim)))
-    elif layout_str in ('c', 'numpy'):
+    elif layout_str in ("c", "numpy"):
         return tuple(range(dim))
     raise ValueError("Unknown layout descriptor " + layout_str)
 
 
-def layout_string_to_tuple(layout_str, dim):
+def layout_string_to_tuple(layout_str, dim) -> tuple[int, ...]:
     if dim <= 0:
         raise ValueError("Dimensionality must be positive")
-    
+
     layout_str = layout_str.lower()
-    if layout_str == 'fzyx' or layout_str == 'soa':
+    if layout_str == "fzyx" or layout_str == "soa":
         if dim > 4:
-            raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.")
+            raise ValueError(
+                f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4."
+            )
         return tuple(reversed(range(dim)))
-    elif layout_str == 'zyxf' or layout_str == 'aos':
+    elif layout_str == "zyxf" or layout_str == "aos":
         if dim > 4:
-            raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.")
+            raise ValueError(
+                f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4."
+            )
         return tuple(reversed(range(dim - 1))) + (dim - 1,)
-    elif layout_str == 'f' or layout_str == 'reverse_numpy':
+    elif layout_str == "f" or layout_str == "reverse_numpy":
         return tuple(reversed(range(dim)))
-    elif layout_str == 'c' or layout_str == 'numpy':
+    elif layout_str == "c" or layout_str == "numpy":
         return tuple(range(dim))
     raise ValueError("Unknown layout descriptor " + layout_str)
 
@@ -1055,7 +1402,8 @@ def compute_strides(shape, layout):
 
 # ---------------------------------------- Parsing of string in fields() function --------------------------------------
 
-field_description_regex = re.compile(r"""
+field_description_regex = re.compile(
+    r"""
     \s*                 # ignore leading white spaces
     (\w+)               # identifier is a sequence of alphanumeric characters, is stored in first group
     (?:                 # optional index specification e.g. (1, 4, 2)
@@ -1066,9 +1414,12 @@ field_description_regex = re.compile(r"""
         \s*
     )?
     \s*,?\s*             # ignore trailing white spaces and comma
-""", re.VERBOSE)
+""",
+    re.VERBOSE,
+)
 
-type_description_regex = re.compile(r"""
+type_description_regex = re.compile(
+    r"""
     \s*
     (\w+)?       # optional dtype
     \s*
@@ -1076,7 +1427,9 @@ type_description_regex = re.compile(r"""
         ([^\]]+)
     \]
     \s*
-""", re.VERBOSE | re.IGNORECASE)
+""",
+    re.VERBOSE | re.IGNORECASE,
+)
 
 
 def _parse_part1(d):
@@ -1094,24 +1447,30 @@ def _parse_description(description):
         result = type_description_regex.match(d)
         if result:
             data_type_str, size_info = result.group(1), result.group(2).strip().lower()
-            if data_type_str is None:
-                data_type_str = 'float64'
-            data_type_str = data_type_str.lower().strip()
+            if data_type_str is not None:
+                data_type_str = data_type_str.lower().strip()
+
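+            #   An empty or missing data type defaults to the dynamic numeric type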
+            if data_type_str:
+                match data_type_str:
+                    case "dyn": dtype = DynamicType.NUMERIC_TYPE
+                    case "dynidx": dtype = DynamicType.INDEX_TYPE
+                    case _: dtype = create_type(data_type_str)
+            else:
+                dtype = DynamicType.NUMERIC_TYPE
 
-            if not data_type_str:
-                data_type_str = 'float64'
-            if size_info.endswith('d'):
+            if size_info.endswith("d"):
                 size_info = int(size_info[:-1])
             else:
                 size_info = tuple(int(e) for e in size_info.split(","))
-            return data_type_str, size_info
+
+            return dtype, size_info
         else:
             raise ValueError("Could not parse field description")
 
-    if ':' in description:
-        field_description, field_info = description.split(':')
+    if ":" in description:
+        field_description, field_info = description.split(":")
     else:
-        field_description, field_info = description, 'float64[2D]'
+        field_description, field_info = description, "float64[2D]"
 
     fields_info = [e for e in _parse_part1(field_description)]
     if not field_info:
diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py
index befb033e6f7969a5ffd9bc7742e9e7ab691da47d..55f1961ca5c00963c16912ada738788688a93452 100644
--- a/src/pystencils/jit/cpu_extension_module.py
+++ b/src/pystencils/jit/cpu_extension_module.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, cast
 
 from os import path
 import hashlib
@@ -19,6 +19,7 @@ from ..types import (
     PsUnsignedIntegerType,
     PsSignedIntegerType,
     PsIeeeFloatType,
+    PsPointerType,
 )
 from ..types.quick import Fp, SInt, UInt
 from ..field import Field
@@ -121,9 +122,7 @@ class PsKernelExtensioNModule:
 
 def emit_call_wrapper(function_name: str, kernel: Kernel) -> str:
     builder = CallWrapperBuilder()
-
-    for p in kernel.parameters:
-        builder.extract_parameter(p)
+    builder.extract_params(kernel.parameters)
 
     # for c in kernel.constraints:
     #     builder.check_constraint(c)
@@ -199,7 +198,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 """
 
     def __init__(self) -> None:
-        self._array_buffers: dict[Field, str] = dict()
+        self._buffer_types: dict[Field, PsType] = dict()
         self._array_extractions: dict[Field, str] = dict()
         self._array_frees: dict[Field, str] = dict()
 
@@ -220,9 +219,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
                 return "PyLong_AsUnsignedLong"
 
             case _:
-                raise ValueError(
-                    f"Don't know how to cast Python objects to {dtype}"
-                )
+                raise ValueError(f"Don't know how to cast Python objects to {dtype}")
 
     def _type_char(self, dtype: PsType) -> str | None:
         if isinstance(
@@ -233,37 +230,39 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         else:
             return None
 
-    def extract_field(self, field: Field) -> str:
-        """Adds an array, and returns the name of the underlying Py_Buffer."""
+    def get_field_buffer(self, field: Field) -> str:
+        """Get the Python buffer object for the given field."""
+        return f"buffer_{field.name}"
+
+    def extract_field(self, field: Field) -> None:
+        """Add the necessary code to extract the NumPy array for a given field"""
         if field not in self._array_extractions:
             extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name)
+            actual_dtype = self._buffer_types[field]
 
             #   Check array type
-            type_char = self._type_char(field.dtype)
+            type_char = self._type_char(actual_dtype)
             if type_char is not None:
                 dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'"
                 extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
                     cond=dtype_cond,
                     what="data type",
                     name=field.name,
-                    expected=str(field.dtype),
+                    expected=str(actual_dtype),
                 )
 
             #   Check item size
-            itemsize = field.dtype.itemsize
+            itemsize = actual_dtype.itemsize
             item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}"
             extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
                 cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize
             )
 
-            self._array_buffers[field] = f"buffer_{field.name}"
             self._array_extractions[field] = extraction_code
 
             release_code = f"PyBuffer_Release(&buffer_{field.name});"
             self._array_frees[field] = release_code
 
-        return self._array_buffers[field]
-
     def extract_scalar(self, param: Parameter) -> str:
         if param not in self._scalar_extractions:
             extract_func = self._scalar_extractor(param.dtype)
@@ -279,7 +278,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
     def extract_array_assoc_var(self, param: Parameter) -> str:
         if param not in self._array_assoc_var_extractions:
             field = param.fields[0]
-            buffer = self.extract_field(field)
+            buffer = self.get_field_buffer(field)
+            buffer_dtype = self._buffer_types[field]
             code: str | None = None
 
             for prop in param.properties:
@@ -293,7 +293,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
                     case FieldStride(_, coord):
                         code = (
                             f"{param.dtype.c_string()} {param.name} = "
-                            f"{buffer}.strides[{coord}] / {field.dtype.itemsize};"
+                            f"{buffer}.strides[{coord}] / {buffer_dtype.itemsize};"
                         )
                         break
             assert code is not None
@@ -302,29 +302,48 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
         return param.name
 
-    def extract_parameter(self, param: Parameter):
-        if param.is_field_parameter:
-            self.extract_array_assoc_var(param)
-        else:
-            self.extract_scalar(param)
+    def extract_params(self, params: tuple[Parameter, ...]) -> None:
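+        #   First pass: resolve the element type of each field's buffer.
+        #   For dynamically typed fields, the element type is taken from the
+        #   base pointer parameter's pointee type.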
+        for param in params:
+            if ptr_props := param.get_properties(FieldBasePtr):
+                prop: FieldBasePtr = cast(FieldBasePtr, ptr_props.pop())
+                field = prop.field
+                actual_field_type: PsType
+
+                from .. import DynamicType
+
+                if isinstance(field.dtype, DynamicType):
+                    ptr_type = param.dtype
+                    assert isinstance(ptr_type, PsPointerType)
+                    actual_field_type = ptr_type.base_type
+                else:
+                    actual_field_type = field.dtype
+
+                self._buffer_types[prop.field] = actual_field_type
+                self.extract_field(prop.field)
+
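+        #   Second pass: emit extraction code for each parameter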
+        for param in params:
+            if param.is_field_parameter:
+                self.extract_array_assoc_var(param)
+            else:
+                self.extract_scalar(param)
 
-#     def check_constraint(self, constraint: KernelParamsConstraint):
-#         variables = constraint.get_parameters()
+    #     def check_constraint(self, constraint: KernelParamsConstraint):
+    #         variables = constraint.get_parameters()
 
-#         for var in variables:
-#             self.extract_parameter(var)
+    #         for var in variables:
+    #             self.extract_parameter(var)
 
-#         cond = constraint.to_code()
+    #         cond = constraint.to_code()
 
-#         code = f"""
-# if(!({cond}))
-# {{
-#     PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); 
-#     return NULL;
-# }}
-# """
+    #         code = f"""
+    # if(!({cond}))
+    # {{
+    #     PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}");
+    #     return NULL;
+    # }}
+    # """
 
-#         self._constraint_checks.append(code)
+    #         self._constraint_checks.append(code)
 
     def call(self, kernel: Kernel, params: tuple[Parameter, ...]):
         param_list = ", ".join(p.name for p in params)
diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py
index c208ac2196151d079ca5081f1377c55d18a9393c..a407bb75e08bfde9911070aef03b4a1769a6221a 100644
--- a/src/pystencils/jit/gpu_cupy.py
+++ b/src/pystencils/jit/gpu_cupy.py
@@ -19,7 +19,7 @@ from ..codegen import (
     Parameter,
 )
 from ..codegen.properties import FieldShape, FieldStride, FieldBasePtr
-from ..types import PsStructType
+from ..types import PsStructType, PsPointerType
 
 from ..include import get_pystencils_include_path
 
@@ -160,8 +160,18 @@ class CupyKernelWrapper(KernelWrapper):
                 for prop in kparam.properties:
                     match prop:
                         case FieldBasePtr(field):
+
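+                            #   Resolve the expected element type: dynamically typed
+                            #   fields take it from the kernel parameter's pointer type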
+                            elem_dtype: PsType
+
+                            from .. import DynamicType
+                            if isinstance(field.dtype, DynamicType):
+                                assert isinstance(kparam.dtype, PsPointerType)
+                                elem_dtype = kparam.dtype.base_type
+                            else:
+                                elem_dtype = field.dtype
+
                             arr = kwargs[field.name]
-                            if arr.dtype != field.dtype.numpy_dtype:
+                            if arr.dtype != elem_dtype.numpy_dtype:
                                 raise JitError(
-                                    f"Data type mismatch at array argument {field.name}:"
-                                    f"Expected {field.dtype}, got {arr.dtype}"
+                                    f"Data type mismatch at array argument {field.name}: "
+                                    f"Expected {elem_dtype}, got {arr.dtype}"
diff --git a/src/pystencils/rng.py b/src/pystencils/rng.py
index d6c6cd2741ee3e7442bd9fa4a96f4e9983d524e3..4f8316fa75284ed0fa3385744bd9b93f88d5ae65 100644
--- a/src/pystencils/rng.py
+++ b/src/pystencils/rng.py
@@ -2,7 +2,7 @@ import copy
 import numpy as np
 import sympy as sp
 
-from .sympyextensions import TypedSymbol, CastFunc, fast_subs
+from .sympyextensions import TypedSymbol, tcast, fast_subs
 # from pystencils.sympyextensions.astnodes import LoopOverCoordinate # TODO nbackend: replace
 # from pystencils.backends.cbackend import CustomCodeNode # TODO nbackend: replace
 
@@ -48,7 +48,7 @@ class RNGBase:
     def get_code(self, dialect, vector_instruction_set, print_arg):
         code = "\n"
         for r in self.result_symbols:
-            if vector_instruction_set and not self.args[1].atoms(CastFunc):
+            if vector_instruction_set and not self.args[1].atoms(tcast):
                 # this vector RNG has become scalar through substitution
                 code += f"{r.dtype} {r.name};\n"
             else:
diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py
index 7431416c9eb9bcd4433dab76c32fb1b755501105..2d874fdc0778a331aaf61ed938981f533eafbecb 100644
--- a/src/pystencils/sympyextensions/__init__.py
+++ b/src/pystencils/sympyextensions/__init__.py
@@ -1,5 +1,5 @@
 from .astnodes import ConditionalFieldAccess
-from .typed_sympy import TypedSymbol, CastFunc
+from .typed_sympy import TypedSymbol, tcast
 from .pointers import mem_acc
 
 from .math import (
@@ -34,7 +34,7 @@ from .math import (
 __all__ = [
     "ConditionalFieldAccess",
     "TypedSymbol",
-    "CastFunc",
+    "tcast",
     "mem_acc",
     "remove_higher_order_terms",
     "prod",
diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py
index 9841a98bd83162fbb080db370556de70612bc398..33c035499ee80598303c8a26b028e47dfae72cc3 100644
--- a/src/pystencils/sympyextensions/math.py
+++ b/src/pystencils/sympyextensions/math.py
@@ -11,7 +11,7 @@ from sympy.functions import Abs
 from sympy.core.numbers import Zero
 
 from ..assignment import Assignment
-from .typed_sympy import CastFunc
+from .typed_sympy import TypeCast
 from ..types import PsPointerType, PsVectorType
 
 T = TypeVar('T')
@@ -603,7 +603,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
             visit_children = False
         elif t.is_integer:
             pass
-        elif isinstance(t, CastFunc):
+        elif isinstance(t, TypeCast):
             visit_children = False
             visit(t.args[0])
         elif t.func is fast_sqrt:
diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py
index 39202296b477dc17ea6e9564548ef841fd04594d..e2435d6bbe570887e0903c67f6041ed9911c02be 100644
--- a/src/pystencils/sympyextensions/typed_sympy.py
+++ b/src/pystencils/sympyextensions/typed_sympy.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import cast
 
 import sympy as sp
 from enum import Enum, auto
@@ -6,11 +7,14 @@ from enum import Enum, auto
 from ..types import (
     PsType,
     PsNumericType,
-    PsBoolType,
     create_type,
     UserTypeSpec
 )
 
+from sympy.logic.boolalg import Boolean
+
+from warnings import warn
+
 
 def is_loop_counter_symbol(symbol):
     from ..defaults import DEFAULTS
@@ -37,11 +41,12 @@ class DynamicType(Enum):
 class TypeAtom(sp.Atom):
     """Wrapper around a type to disguise it as a SymPy atom."""
 
-    def __new__(cls, *args, **kwargs):
-        return sp.Basic.__new__(cls)
+    _dtype: PsType | DynamicType
 
-    def __init__(self, dtype: PsType | DynamicType) -> None:
-        self._dtype = dtype
+    def __new__(cls, dtype: PsType | DynamicType):
+        obj = super().__new__(cls)
+        obj._dtype = dtype
+        return obj
 
     def _sympystr(self, *args, **kwargs):
         return str(self._dtype)
@@ -52,6 +57,9 @@ class TypeAtom(sp.Atom):
     def _hashable_content(self):
         return (self._dtype,)
     
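+    #   Needed for pickling: the dtype is set in __new__ and is not part of the SymPy args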
+    def __getnewargs__(self):
+        return (self._dtype,)
+    
 
 def assumptions_from_dtype(dtype: PsType | DynamicType):
     """Derives SymPy assumptions from :class:`PsAbstractType`
@@ -133,144 +141,74 @@ class TypedSymbol(sp.Symbol):
         return self.dtype.required_headers if isinstance(self.dtype, PsType) else set()
 
 
-class CastFunc(sp.Function):
-    """Use this function to introduce a static type cast into the output code.
-
-    Usage: ``CastFunc(expr, target_type)`` becomes, in C code, ``(target_type) expr``.
-    The ``target_type`` may be a valid pystencils type specification parsable by `create_type`,
-    or a special value of the `DynamicType` enum.
-    These dynamic types can be used to select the target type according to the code generation context.
-    """
+class TypeCast(sp.Function):
+    """Explicitly cast an expression to a data type."""
 
     @staticmethod
     def as_numeric(expr):
-        return CastFunc(expr, DynamicType.NUMERIC_TYPE)
+        return TypeCast(expr, DynamicType.NUMERIC_TYPE)
 
     @staticmethod
     def as_index(expr):
-        return CastFunc(expr, DynamicType.INDEX_TYPE)
-
-    is_Atom = True
-
-    def __new__(cls, *args, **kwargs):
-        if len(args) != 2:
-            pass
-        expr, dtype, *other_args = args
-
-        # If we have two consecutive casts, throw the inner one away.
-        # This optimisation is only available for simple casts. Thus the == is intended here!
-        if expr.__class__ == CastFunc:
-            expr = expr.args[0]
-
-        if not isinstance(dtype, (TypeAtom)):
-            if isinstance(dtype, DynamicType):
-                dtype = TypeAtom(dtype)
-            else:
-                dtype = TypeAtom(create_type(dtype))
-
-        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
-        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
-        # to problems when for example comparing cast_func's for equality
-        #
-        # lhs = bitwise_and(a, cast_func(1, 'int'))
-        # rhs = cast_func(0, 'int')
-        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
-        # -> thus a separate class boolean_cast_func is introduced
-        if isinstance(expr, sp.logic.boolalg.Boolean) and (
-            not isinstance(expr, TypedSymbol) or isinstance(expr.dtype, PsBoolType)
-        ):
-            cls = BooleanCastFunc
-
-        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)
-
-    @property
-    def canonical(self):
-        if hasattr(self.args[0], "canonical"):
-            return self.args[0].canonical
-        else:
-            raise NotImplementedError()
-
-    @property
-    def is_commutative(self):
-        return self.args[0].is_commutative
-
-    @property
-    def dtype(self) -> PsType | DynamicType:
-        assert isinstance(self.args[1], TypeAtom)
-        return self.args[1].get()
-
+        return TypeCast(expr, DynamicType.INDEX_TYPE)
+    
     @property
-    def expr(self):
+    def expr(self) -> sp.Basic:
         return self.args[0]
 
     @property
-    def is_integer(self):
+    def dtype(self) -> PsType | DynamicType:
+        return cast(TypeAtom, self._args[1]).get()
+    
+    def __new__(cls, expr: sp.Basic, dtype: UserTypeSpec | DynamicType | TypeAtom):
+        tatom: TypeAtom
+        match dtype:
+            case TypeAtom():
+                tatom = dtype
+            case DynamicType():
+                tatom = TypeAtom(dtype)
+            case _:
+                tatom = TypeAtom(create_type(dtype))
+        
+        return super().__new__(cls, expr, tatom)
+    
+    @classmethod
+    def eval(cls, expr: sp.Basic, tatom: TypeAtom) -> TypeCast | None:
+        dtype = tatom.get()
+        if cls is not BoolCast and isinstance(dtype, PsNumericType) and dtype.is_bool():
+            return BoolCast(expr, tatom)
+        
+        return None
+    
+    def _eval_is_integer(self):
         if self.dtype == DynamicType.INDEX_TYPE:
             return True
-        elif isinstance(self.dtype, PsNumericType):
-            return self.dtype.is_int() or super().is_integer
-        else:
-            return super().is_integer
-
-    @property
-    def is_negative(self):
-        """
-        See :func:`.TypedSymbol.is_integer`
-        """
-        if isinstance(self.dtype, PsNumericType):
-            if self.dtype.is_uint():
-                return False
-
-        return super().is_negative
-
-    @property
-    def is_nonnegative(self):
-        """
-        See :func:`.TypedSymbol.is_integer`
-        """
-        if self.is_negative is False:
+        if isinstance(self.dtype, PsNumericType) and self.dtype.is_int():
+            return True
+        
+    def _eval_is_real(self):
+        if isinstance(self.dtype, DynamicType):
+            return True
+        if isinstance(self.dtype, PsNumericType) and (
+            self.dtype.is_float() or self.dtype.is_int()
+        ):
+            return True
+        
+    def _eval_is_nonnegative(self):
+        if isinstance(self.dtype, PsNumericType) and self.dtype.is_uint():
             return True
-        else:
-            return super().is_nonnegative
-
-    @property
-    def is_real(self):
-        """
-        See :func:`.TypedSymbol.is_integer`
-        """
-        if isinstance(self.dtype, PsNumericType):
-            return self.dtype.is_int() or self.dtype.is_float() or super().is_real
-        else:
-            return super().is_real
-
-
-class BooleanCastFunc(CastFunc, sp.logic.boolalg.Boolean):
-    # TODO: documentation
-    pass
-
-
-class VectorMemoryAccess(CastFunc):
-    """
-    Special memory access for vectorized kernel.
-    Arguments: read/write expression, type, aligned, non-temporal, mask (or none), stride
-    """
 
-    nargs = (6,)
 
+class BoolCast(TypeCast, Boolean):
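+    """Cast to a boolean type; returned by `TypeCast` when the target type is boolean.
+
+    Deriving from ``Boolean`` allows cast results to be used in boolean contexts,
+    e.g. as conditions of ``sp.Piecewise``.
+    """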
+    pass
 
-class ReinterpretCastFunc(CastFunc):
-    """
-    Reinterpret cast is necessary for the StructType
-    """
 
-    pass
+tcast = TypeCast
 
 
-class PointerArithmeticFunc(sp.Function, sp.logic.boolalg.Boolean):
-    # TODO: documentation, or deprecate!
-    @property
-    def canonical(self):
-        if hasattr(self.args[0], "canonical"):
-            return self.args[0].canonical
-        else:
-            raise NotImplementedError()
+class CastFunc(TypeCast):
+    def __new__(cls, *args, **kwargs):
+        warn(
+            "CastFunc is deprecated and will be removed in pystencils 2.1. "
+            "Use `pystencils.tcast` instead.",
+            FutureWarning
+        )
+        return TypeCast.__new__(cls, *args, **kwargs)
diff --git a/tests/frontend/test_address_of.py b/tests/frontend/test_address_of.py
index 99f33ddbdfa7054bf5f27c08848640ee03f64555..62d7f00d56b288c009c9dc4fcfade95b95acdd41 100644
--- a/tests/frontend/test_address_of.py
+++ b/tests/frontend/test_address_of.py
@@ -5,7 +5,7 @@ import pytest
 import pystencils
 from pystencils.types import PsPointerType, create_type
 from pystencils.sympyextensions.pointers import AddressOf
-from pystencils.sympyextensions.typed_sympy import CastFunc
+from pystencils.sympyextensions.typed_sympy import tcast
 from pystencils.simp import sympy_cse
 
 import sympy as sp
@@ -23,14 +23,14 @@ def test_address_of():
 
     assignments = pystencils.AssignmentCollection({
         s: AddressOf(x[0, 0]),
-        y[0, 0]: CastFunc(s, create_type('int64'))
+        y[0, 0]: tcast(s, create_type('int64'))
     })
 
     _ = pystencils.create_kernel(assignments).compile()
     # pystencils.show_code(kernel.ast)
 
     assignments = pystencils.AssignmentCollection({
-        y[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64'))
+        y[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64'))
     })
 
     _ = pystencils.create_kernel(assignments).compile()
@@ -41,7 +41,7 @@ def test_address_of_with_cse():
     x, y = pystencils.fields('x, y: int64[2d]')
 
     assignments = pystencils.AssignmentCollection({
-        x[0, 0]: CastFunc(AddressOf(x[0, 0]), create_type('int64')) + 1
+        x[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64')) + 1
     })
 
     _ = pystencils.create_kernel(assignments).compile()
diff --git a/tests/frontend/test_field.py b/tests/frontend/test_field.py
index dc804491bee8023e7b0e1b665d5f9cd252d64c1d..6d2942569704b7ff85b15fd23432667ba109ed7d 100644
--- a/tests/frontend/test_field.py
+++ b/tests/frontend/test_field.py
@@ -3,7 +3,7 @@ import pytest
 import sympy as sp
 
 import pystencils as ps
-from pystencils import DEFAULTS
+from pystencils import DEFAULTS, DynamicType, create_type, fields
 from pystencils.field import (
     Field,
     FieldType,
@@ -15,6 +15,7 @@ from pystencils.field import (
 def test_field_basic():
     f = Field.create_generic("f", spatial_dimensions=2)
     assert FieldType.is_generic(f)
+    assert f.dtype == DynamicType.NUMERIC_TYPE
     assert f["E"] == f[1, 0]
     assert f["N"] == f[0, 1]
     assert "_" in f.center._latex("dummy")
@@ -41,17 +42,16 @@ def test_field_basic():
     assert f1.ndim == f.ndim
     assert f1.values_per_cell() == f.values_per_cell()
 
-    fixed = ps.fields("f(5, 5) : double[20, 20]")
-    assert fixed.neighbor_vector((1, 1)).shape == (5, 5)
-
-    f = Field.create_fixed_size("f", (10, 10), strides=(80, 8), dtype=np.float64)
+    f = Field.create_fixed_size("f", (10, 10), strides=(10, 1), dtype=np.float64)
     assert f.spatial_strides == (10, 1)
     assert f.index_strides == ()
     assert f.center_vector == sp.Matrix([f.center])
+    assert f.dtype == create_type("float64")
 
     f1 = f.new_field_with_different_name("f1")
     assert f1.ndim == f.ndim
     assert f1.values_per_cell() == f.values_per_cell()
+    assert f1.dtype == create_type("float64")
 
     f = Field.create_fixed_size("f", (8, 8, 2, 2), index_dimensions=2)
     assert f.center_vector == sp.Matrix([[f(0, 0), f(0, 1)], [f(1, 0), f(1, 1)]])
@@ -61,16 +61,48 @@ def test_field_basic():
     neighbor = field_access.neighbor(coord_id=0, offset=-2)
     assert neighbor.offsets == (-1, 1)
     assert "_" in neighbor._latex("dummy")
+    assert f.dtype == DynamicType.NUMERIC_TYPE
 
     f = Field.create_fixed_size("f", (8, 8, 2, 2, 2), index_dimensions=3)
     assert f.center_vector == sp.Array(
         [[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)]
     )
+    assert f.dtype == DynamicType.NUMERIC_TYPE
 
     f = Field.create_generic("f", spatial_dimensions=5, index_dimensions=2)
     field_access = f[1, -1, 2, -3, 0](1, 0)
     assert field_access.offsets == (1, -1, 2, -3, 0)
     assert field_access.index == (1, 0)
+    assert f.dtype == DynamicType.NUMERIC_TYPE
+
+
+def test_field_description_parsing():
+    f, g = fields("f(1), g(3): [2D]")
+
+    assert f.dtype == g.dtype == DynamicType.NUMERIC_TYPE
+    assert f.spatial_dimensions == g.spatial_dimensions == 2
+    assert f.index_shape == (1,)
+    assert g.index_shape == (3,)
+
+    f = fields("f: dyn[3D]")
+    assert f.dtype == DynamicType.NUMERIC_TYPE
+
+    idx = fields("idx: dynidx[3D]")
+    assert idx.dtype == DynamicType.INDEX_TYPE
+
+    h = fields("h: float32[3D]")
+
+    assert h.index_shape == ()
+    assert h.spatial_dimensions == 3
+    assert h.index_dimensions == 0
+    assert h.dtype == create_type("float32")
+
+    f: Field = fields("f(5, 5) : double[20, 20]")
+    
+    assert f.dtype == create_type("float64")
+    assert f.spatial_shape == (20, 20)
+    assert f.index_shape == (5, 5)
+    assert f.neighbor_vector((1, 1)).shape == (5, 5)
 
 
 def test_error_handling():
@@ -145,7 +177,7 @@ def test_error_handling():
 
 
 def test_decorator_scoping():
-    dst = ps.fields("dst : double[2D]")
+    dst = fields("dst : double[2D]")
 
     def f1():
         a = sp.Symbol("a")
@@ -165,7 +197,7 @@ def test_decorator_scoping():
 
 
 def test_string_creation():
-    x, y, z = ps.fields("  x(4),    y(3,5) z : double[  3,  47]")
+    x, y, z = fields("  x(4),    y(3,5) z : double[  3,  47]")
     assert x.index_shape == (4,)
     assert y.index_shape == (3, 5)
     assert z.spatial_shape == (3, 47)
@@ -173,9 +205,9 @@ def test_string_creation():
 
 def test_itemsize():
 
-    x = ps.fields("x: float32[1d]")
-    y = ps.fields("y:  float64[2d]")
-    i = ps.fields("i:  int16[1d]")
+    x = fields("x: float32[1d]")
+    y = fields("y:  float64[2d]")
+    i = fields("i:  int16[1d]")
 
     assert x.itemsize == 4
     assert y.itemsize == 8
@@ -249,7 +281,7 @@ def test_memory_layout_descriptors():
 def test_staggered():
 
     # D2Q5
-    j1, j2, j3 = ps.fields(
+    j1, j2, j3 = fields(
         "j1(2), j2(2,2), j3(2,2,2) : double[2D]", field_type=FieldType.STAGGERED
     )
 
@@ -296,7 +328,7 @@ def test_staggered():
     )
 
     # D2Q9
-    k1, k2 = ps.fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED)
+    k1, k2 = fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED)
 
     assert k1[1, 1](2) == k1.staggered_access("NE")
     assert k1[0, 0](2) == k1.staggered_access("SW")
@@ -319,7 +351,7 @@ def test_staggered():
     ]
 
     # sign reversed when using as flux field
-    r = ps.fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX)
+    r = fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX)
     assert r[0, 0](0) == r.staggered_access("W")
     assert -r[1, 0](0) == r.staggered_access("E")
 
diff --git a/tests/frontend/test_typed_sympy.py b/tests/frontend/test_typed_sympy.py
index 41015f96bfa6950a57f9ccfa3194c128c2bc0f69..bf6058537a7217851d22987f3b011edea08058c8 100644
--- a/tests/frontend/test_typed_sympy.py
+++ b/tests/frontend/test_typed_sympy.py
@@ -1,8 +1,11 @@
 import numpy as np
+import pickle
+import sympy as sp
+from sympy.logic import boolalg
 
 from pystencils.sympyextensions.typed_sympy import (
     TypedSymbol,
-    CastFunc,
+    tcast,
     TypeAtom,
     DynamicType,
 )
@@ -12,7 +15,7 @@ from pystencils.types.quick import UInt, Ptr
 
 def test_type_atoms():
     atom1 = TypeAtom(create_type("int32"))
-    atom2 = TypeAtom(create_type("int32"))
+    atom2 = TypeAtom(create_type(np.int32))
 
     assert atom1 == atom2
 
@@ -25,6 +28,11 @@ def test_type_atoms():
     assert atom3 != atom4
     assert atom4 != atom5
 
+    dump = pickle.dumps(atom1)
+    atom1_reconst = pickle.loads(dump)
+
+    assert atom1_reconst == atom1
+
 
 def test_typed_symbol():
     x = TypedSymbol("x", "uint32")
@@ -46,12 +54,34 @@ def test_typed_symbol():
     assert not z.is_nonnegative
 
 
-def test_cast_func():
-    assert (
-        CastFunc(TypedSymbol("s", np.uint), np.int64).canonical
-        == TypedSymbol("s", np.uint).canonical
-    )
-
-    a = CastFunc(5, np.uint)
-    assert a.is_negative is False
-    assert a.is_nonnegative
+def test_casts():
+    x, y = sp.symbols("x, y")
+
+    #   Pickling
+    expr = tcast(x, "int32")
+    dump = pickle.dumps(expr)
+    expr_reconst = pickle.loads(dump)
+    assert expr_reconst == expr
+
+    #   Boolean Casts
+    bool_expr = tcast(x, "bool")
+    assert isinstance(bool_expr, boolalg.Boolean)
+    
+    #   Check that we can construct boolean expressions with cast results
+    _ = boolalg.Or(bool_expr, y)
+    
+    #   Assumptions
+    expr = tcast(x, "int32")
+    assert expr.is_integer
+    assert expr.is_real
+    assert expr.is_nonnegative is None
+
+    expr = tcast(x, "uint32")
+    assert expr.is_integer
+    assert expr.is_real
+    assert expr.is_nonnegative
+
+    expr = tcast(x, "float32")
+    assert expr.is_integer is None
+    assert expr.is_real
+    assert expr.is_nonnegative is None
diff --git a/tests/kernelcreation/test_spatial_counters.py b/tests/kernelcreation/test_spatial_counters.py
index fdb365294c98311943c370cb650694b1a4bd8613..4f865ad97f42f31133cc5d0dc3fbba569f6f577d 100644
--- a/tests/kernelcreation/test_spatial_counters.py
+++ b/tests/kernelcreation/test_spatial_counters.py
@@ -9,7 +9,7 @@ from pystencils import (
     DEFAULTS,
     FieldType,
 )
-from pystencils.sympyextensions import CastFunc
+from pystencils.sympyextensions import tcast
 
 
 @pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
@@ -21,9 +21,9 @@ def test_spatial_counters_dense(index_dtype):
     f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
 
     asms = [
-        Assignment(f(0), CastFunc.as_numeric(z)),
-        Assignment(f(1), CastFunc.as_numeric(y)),
-        Assignment(f(2), CastFunc.as_numeric(x)),
+        Assignment(f(0), tcast.as_numeric(z)),
+        Assignment(f(1), tcast.as_numeric(y)),
+        Assignment(f(2), tcast.as_numeric(x)),
     ]
 
     cfg = CreateKernelConfig(index_dtype=index_dtype)
@@ -44,9 +44,9 @@ def test_spatial_counters_sparse(index_dtype):
     f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
 
     asms = [
-        Assignment(f(0), CastFunc.as_numeric(x)),
-        Assignment(f(1), CastFunc.as_numeric(y)),
-        Assignment(f(2), CastFunc.as_numeric(z)),
+        Assignment(f(0), tcast.as_numeric(x)),
+        Assignment(f(1), tcast.as_numeric(y)),
+        Assignment(f(2), tcast.as_numeric(z)),
     ]
 
     idx_struct = DEFAULTS.index_struct(index_dtype, 3)
diff --git a/tests/kernelcreation/test_type_cast.py b/tests/kernelcreation/test_type_cast.py
index 8ad6d867042ff6e57e5baee8c12ff45bae17e8e4..6b7acbbedbe395f36f306b76eeb09b8edc7444d9 100644
--- a/tests/kernelcreation/test_type_cast.py
+++ b/tests/kernelcreation/test_type_cast.py
@@ -8,7 +8,7 @@ from pystencils import (
     Assignment,
     Field,
 )
-from pystencils.sympyextensions.typed_sympy import CastFunc
+from pystencils.sympyextensions.typed_sympy import tcast
 
 
 AVAIL_TARGETS_NO_SSE = [t for t in Target.available_targets() if Target._SSE not in t]
@@ -55,7 +55,7 @@ def test_type_cast(gen_config, xp, from_type, to_type):
     inp_field = Field.create_from_numpy_array("inp", inp)
     outp_field = Field.create_from_numpy_array("outp", outp)
 
-    asms = [Assignment(outp_field.center(), CastFunc(inp_field.center(), to_type))]
+    asms = [Assignment(outp_field.center(), tcast(inp_field.center(), to_type))]
 
     kernel = create_kernel(asms, gen_config)
     kfunc = kernel.compile()
diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py
index fccec711423699c3d413682c3a5e8a99d5e092f1..f6c8f85b2b3df2289e809728b9e7b014d6428976 100644
--- a/tests/nbackend/kernelcreation/test_freeze.py
+++ b/tests/nbackend/kernelcreation/test_freeze.py
@@ -9,7 +9,7 @@ from pystencils import (
     TypedSymbol,
     DynamicType,
 )
-from pystencils.sympyextensions import CastFunc
+from pystencils.sympyextensions import tcast
 from pystencils.sympyextensions.pointers import mem_acc
 
 from pystencils.backend.ast.structural import (
@@ -312,16 +312,16 @@ def test_cast_func():
     y2 = PsExpression.make(ctx.get_symbol("y"))
     z2 = PsExpression.make(ctx.get_symbol("z"))
 
-    expr = freeze(CastFunc(x, create_type("int")))
+    expr = freeze(tcast(x, create_type("int")))
     assert expr.structurally_equal(PsCast(create_type("int"), x2))
 
-    expr = freeze(CastFunc.as_numeric(y))
+    expr = freeze(tcast.as_numeric(y))
     assert expr.structurally_equal(PsCast(ctx.default_dtype, y2))
 
-    expr = freeze(CastFunc.as_index(z))
+    expr = freeze(tcast.as_index(z))
     assert expr.structurally_equal(PsCast(ctx.index_dtype, z2))
 
-    expr = freeze(CastFunc(42, create_type("int16")))
+    expr = freeze(tcast(42, create_type("int16")))
     assert expr.structurally_equal(PsConstantExpr(PsConstant(42, create_type("int16"))))
 
 
diff --git a/tests/nbackend/transformations/test_ast_vectorizer.py b/tests/nbackend/transformations/test_ast_vectorizer.py
index f92f1e768a6e04c1eb5292612d6406365520bb72..3ccb479e5552bcd02954b9ed8518ef3ad0f90bfb 100644
--- a/tests/nbackend/transformations/test_ast_vectorizer.py
+++ b/tests/nbackend/transformations/test_ast_vectorizer.py
@@ -2,7 +2,7 @@ import sympy as sp
 import pytest
 
 from pystencils import Assignment, TypedSymbol, fields, FieldType, make_slice
-from pystencils.sympyextensions import CastFunc, mem_acc
+from pystencils.sympyextensions import tcast, mem_acc
 from pystencils.sympyextensions.pointers import AddressOf
 
 from pystencils.backend.constants import PsConstant
@@ -109,7 +109,7 @@ def test_vectorize_casts_and_counter():
     axis = VectorizationAxis(ctr, vec_ctr)
     vc = VectorizationContext(ctx, 4, axis)
 
-    expr = factory.parse_sympy(CastFunc(sp.Symbol("ctr"), create_type("float32")))
+    expr = factory.parse_sympy(tcast(sp.Symbol("ctr"), create_type("float32")))
     vec_expr = vectorize.visit(expr, vc)
 
     assert isinstance(vec_expr, PsCast)
@@ -136,7 +136,7 @@ def test_invalid_vectorization():
     axis = VectorizationAxis(ctr)
     vc = VectorizationContext(ctx, 4, axis)
 
-    expr = factory.parse_sympy(CastFunc(sp.Symbol("ctr"), create_type("float32")))
+    expr = factory.parse_sympy(tcast(sp.Symbol("ctr"), create_type("float32")))
 
     with pytest.raises(VectorizationError):
         #   Fails since no vectorized counter was specified
@@ -177,7 +177,7 @@ def test_vectorize_declarations():
         [
             factory.parse_sympy(asm)
             for asm in [
-                Assignment(x, CastFunc.as_numeric(ctr)),
+                Assignment(x, tcast.as_numeric(ctr)),
                 Assignment(y, sp.cos(x)),
                 Assignment(z, x**2 + 2 * y / 4),
                 Assignment(w, -x + y - z),
diff --git a/tests/nbackend/transformations/test_canonicalize_symbols.py b/tests/nbackend/transformations/test_canonicalize_symbols.py
index 2758d123417eb8e3015ed1d6b4d8cf0ba7c14611..dbc4ba10b71668af43d3a352d9cc49a8c9d61140 100644
--- a/tests/nbackend/transformations/test_canonicalize_symbols.py
+++ b/tests/nbackend/transformations/test_canonicalize_symbols.py
@@ -17,7 +17,7 @@ def test_deduplication():
     factory = AstFactory(ctx)
     canonicalize = CanonicalizeSymbols(ctx)
 
-    f = Field.create_fixed_size("f", (5, 5), strides=(5, 1))
+    f = Field.create_fixed_size("f", (5, 5), memory_strides=(5, 1))
     x, y, z = sp.symbols("x, y, z")
 
     ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f)
diff --git a/tests/nbackend/transformations/test_hoist_invariants.py b/tests/nbackend/transformations/test_hoist_invariants.py
index daa2760c0b376dc0bb1f2ca59703a15efc5c2312..1f27a5a4cd17d9b20e54b3c44d1e733f8374f947 100644
--- a/tests/nbackend/transformations/test_hoist_invariants.py
+++ b/tests/nbackend/transformations/test_hoist_invariants.py
@@ -33,7 +33,7 @@ def test_hoist_multiple_loops():
     canonicalize = CanonicalizeSymbols(ctx)
     hoist = HoistLoopInvariantDeclarations(ctx)
 
-    f = Field.create_fixed_size("f", (5, 5), strides=(5, 1))
+    f = Field.create_fixed_size("f", (5, 5), memory_strides=(5, 1))
     x, y, z = sp.symbols("x, y, z")
 
     ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f)
diff --git a/tests/runtime/test_datahandling.py b/tests/runtime/test_datahandling.py
index 62ba64056ab6d4062e49b76376d4e3cf3560ccf2..9d7ff924e8d7eba9039f8f0796145bd7de116ef5 100644
--- a/tests/runtime/test_datahandling.py
+++ b/tests/runtime/test_datahandling.py
@@ -249,7 +249,7 @@ def test_add_arrays():
     dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=0, default_layout='numpy')
     x_, y_ = dh.add_arrays(field_description)
 
-    x, y = ps.fields(field_description + ': [3,4,5]')
+    x, y = ps.fields(field_description + ': float64[3,4,5]')
 
     assert x_ == x
     assert y_ == y