From ac0dfad306eebc71443d9973be6aeb0cfd5011d9 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 24 Jan 2025 11:39:42 +0100
Subject: [PATCH] extend dyn-type support for fields. Extend user guide on
 typing.

---
 docs/source/api/sympyextensions.rst       |   3 +
 docs/source/reference/WorkingWithTypes.md | 155 ++++++++++++++++------
 src/pystencils/field.py                   |  31 ++++-
 tests/frontend/test_field.py              |   6 +
 4 files changed, 154 insertions(+), 41 deletions(-)

diff --git a/docs/source/api/sympyextensions.rst b/docs/source/api/sympyextensions.rst
index 6bd8fc6ee..cbf94f8c5 100644
--- a/docs/source/api/sympyextensions.rst
+++ b/docs/source/api/sympyextensions.rst
@@ -71,6 +71,9 @@ Typed Expressions
 .. autoclass:: pystencils.DynamicType
     :members:
 
+.. autoclass:: pystencils.sympyextensions.typed_sympy.TypeCast
+    :members:
+
 .. autoclass:: pystencils.sympyextensions.tcast
 
 
diff --git a/docs/source/reference/WorkingWithTypes.md b/docs/source/reference/WorkingWithTypes.md
index 707553296..0fc79ecdb 100644
--- a/docs/source/reference/WorkingWithTypes.md
+++ b/docs/source/reference/WorkingWithTypes.md
@@ -1,7 +1,11 @@
 ---
 file_format: mystnb
 kernelspec:
-    name: python3
+  display_name: Python 3 (ipykernel)
+  language: python
+  name: python3
+mystnb:
+  execution_mode: cache
 ---
 
 # Working with Data Types
@@ -13,54 +17,129 @@ Individual fields and symbols,
 single subexpressions,
 or the entire kernel.
 
-```{code-cell}
+```{code-cell} ipython3
 :tags: [remove-cell]
 import pystencils as ps
+import sympy as sp
 ```
 
-## Understanding the pystencils Type Inference Algorithm
+## Setting the Types of Fields and Symbols
 
-To correctly apply varying data types to pystencils kernels, it is important to understand
-how pystencils computes and propagates the data types of expressions.
-These are the rules by which untyped symbols learn their data type:
-
- - All *free symbols* (that is, symbols not defined by an assignment in the kernel) receive the
-  *default data type* from the code generator configuration.
- - All symbols defined using a *constant expression* also receive the default data type.
- - All other symbols receive the data type computed for the right-hand side expression of their
-   defining assignment.
-
-To determine the type of right hand-side expressions, pystencils looks for any subexpressions
-nested inside them which already have a known type.
-These might be symbols whose type was already determined,
-or expressions with a fixed type, such as field accesses (which get the element type of their field),
-explicitly typed symbols, or type casts (see [](explicit_expression_types)).
-
-:::{attention}
-Expressions must always have a unique and unambiguous data type,
-and pystencils will not introduce any implicit casts.
-If pystencils finds subexpressions with conflicting types inside one expression,
-type inference will fail and `create_kernel` will raise an error.
-:::
+### Untyped Symbols
+
+Symbols used inside a kernel are most commonly created using
+{any}`sp.symbols <sympy.core.symbol.symbols>` or
+{any}`sp.Symbol <sympy.core.symbol.Symbol>`.
+These symbols are *untyped*; they will receive a type during code generation
+according to these rules:
+ - Free untyped symbols (i.e. symbols not defined by an assignment inside the kernel) receive the 
+   {any}`default data type <CreateKernelConfig.default_dtype>` specified in the code generator configuration.
+ - Bound untyped symbols (i.e. symbols that *are* defined in an assignment)
+   receive the data type that was computed for the right-hand side expression of their defining assignment.
 
-Through these rules, there are multiple ways to modify the data types used inside a kernel.
-These are highlighted in the following sections.
+If you are working on kernels with homogenous data types, using untyped symbols will mostly be enough.
 
-## Changing the Default Data Type
+### Explicitly Typed Symbols and Fields
 
-The *default data type* is the fallback type assigned by pystencils to all free symbols and symbols with constant
-definitions.
-It can be modified by setting the {any}`default_dtype <CreateKernelConfig.default_dtype>` option
-of the code generator configuration:
+If you need more control over the data types in (parts of) your kernel,
+you will have to explicitly specify them.
+To set an explicit data type for a symbol, use the {any}`TypedSymbol` class of pystencils:
 
 ```{code-cell} ipython3
-cfg = ps.CreateKernelConfig()
-cfg.default_dtype = "float32"
+x_typed = ps.TypedSymbol("x", "uint32")
+x_typed, str(x_typed.dtype)
 ```
 
-:::{admonition} Developers To Do
-Fields should use DynamicType by default!
-:::
+You can set a `TypedSymbol` to any data type provided by [the type system](#page_type_system),
+which will then be enforced by the code generator.
+
+The same holds for fields:
+When creating fields through the {any}`fields <pystencils.field.fields>` function,
+add the type to the descriptor string; for instance:
+
+```{code-cell} ipython3
+f, g = ps.fields("f(1), g(3): float32[3D]")
+str(f.dtype), str(g.dtype)
+```
+
+When using `Field.create_generic` or `Field.create_fixed_size`, on the other hand,
+you can set the data type via the `dtype` keyword argument.
+
+### Dynamically Typed Symbols and Fields
+
+Apart from explicitly setting data types,
+`TypedSymbol`s and fields can also receive a *dynamic data type* (see {any}`DynamicType`).
+There are two options:
+ - Symbols or fields annotated with {any}`DynamicType.NUMERIC_TYPE` will always receive
+   the {any}`default numeric type <CreateKernelConfig.default_dtype>` configured for the
+   code generator.
+   This is the default setting for fields
+   created through `fields`, `Field.create_generic` or `Field.create_fixed_size`.
+ - When annotated with {any}`DynamicType.INDEX_TYPE`, on the other hand, they will receive
+   the {any}`index data type <CreateKernelConfig.index_dtype>` configured for the kernel.
+
+Using dynamic typing, you can enforce symbols to receive either the standard numeric or
+index type without explicitly stating it, such that your kernel definition becomes
+independent from the code generator configuration.
+
+## Mixing Types Inside Expressions
+
+Pystencils enforces that all symbols, constants, and fields occuring inside an expression
+have the same data type.
+The code generator will never introduce implicit casts --
+if any type conflicts arise, it will terminate with an error.
+
+Still, there are cases where you want to combine subexpressions of different types;
+maybe you need to compute geometric information from loop counters or other integers,
+or you are doing mixed-precision numerical computations.
+In these cases, you might have to
+ 1. Introduce explicit type casts when values move from one type context to another;
+ 2. Annotate expressions with a specific data type to ensure computations are performed in that type.
+
+(type_casts)=
+### Type Casts
+
+Type casts can be introduced into kernels using the {any}`tcast` symbolic function.
+It takes an expression and a data type, which is either an explicit type (see [the type system](#page_type_system))
+or a dynamic type ({any}`DynamicType`):
+
+```{code-cell} ipython3
+x, y = sp.symbols("x, y")
+expr1 = ps.tcast(x, "float32")
+expr2 = ps.tcast(3 + y, ps.DynamicType.INDEX_TYPE)
+
+str(expr1.dtype), str(expr2.dtype)
+```
+
+When a type cast occurs, pystencils will compute the type of its argument independently
+and then introduce a runtime cast to the target type.
+That target type must comply with the type computed for the outer expression,
+which the cast is embedded in.
 
 (explicit_expression_types)=
-## Setting Explicit Types for Expressions
+### Setting Explicit Types for Expressions
+
+
+
+## Understanding the pystencils Type Inference System
+
+To correctly apply varying data types to pystencils kernels, it is important to understand
+how pystencils computes and propagates the data types of symbols and expressions.
+
+Type inference happens on the level of assignments.
+For each assignment $x := \mathrm{calc}(y_1, \dots, y_n)$,
+the system first attempts to compute a *unique* type for the right-hand side (RHS) $\mathrm{calc}(y_1, \dots, y_n)$.
+It searches for any subexpression inside the RHS for which a type is already known --
+these might be typed symbols
+(whose types are either set explicitly by the user,
+or have been determined from their defining assignment),
+field accesses,
+or explicitly typed expressions.
+It then attempts to apply that data type to the entire expression.
+If type conflicts occur, the process fails and the code generator raises an error.
+Otherwise, the resulting type is assigned to the left-hand side symbol $x$.
+
+:::{admonition} Developer's To Do
+It would be great to illustrate this using a GraphViz-plot of an AST,
+with nodes colored according to their data types
+:::
diff --git a/src/pystencils/field.py b/src/pystencils/field.py
index 5d7cb95cc..c64e0afad 100644
--- a/src/pystencils/field.py
+++ b/src/pystencils/field.py
@@ -1123,11 +1123,30 @@ def fields(
     layout=None,
     field_type=FieldType.GENERIC,
     **kwargs,
-) -> Union[Field, List[Field]]:
+) -> Field | list[Field]:
     """Creates pystencils fields from a string description.
 
+    The description must be a string of the form
+    ``"name(index-shape) [name(index-shape) ...] <data-type>[<dimension-or-shape>]"``,
+    where:
+
+    - ``name`` is the name of the respective field
+    - ``(index-shape)`` is a tuple of integers describing the shape of the tensor on each field node
+      (can be omitted for scalar fields)
+    - ``<data-type>`` is the numerical data type of the field's entries;
+      this can be any type parseable by `create_type`,
+      as well as ``dyn`` for `DynamicType.NUMERIC_TYPE`
+      and ``dynidx`` for `DynamicType.INDEX_TYPE`.
+    - ``<dimension-or-shape>`` can be a dimensionality (e.g. ``1D``, ``2D``, ``3D``)
+      or a tuple of integers defining the spatial shape of the field.
+
     Examples:
-        Create a 2D scalar and vector field:
+        Create a 3D scalar field of default numeric type:
+            >>> f = fields("f(1): [2D]")
+            >>> str(f.dtype)
+            'DynamicType.NUMERIC_TYPE'
+
+        Create a 2D scalar and vector field of 64-bit float type:
             >>> s, v = fields("s, v(2): double[2D]")
             >>> assert s.spatial_dimensions == 2 and s.index_dimensions == 0
             >>> assert (v.spatial_dimensions, v.index_dimensions, v.index_shape) == (2, 1, (2,))
@@ -1153,6 +1172,9 @@ def fields(
             >>> f = fields("pdfs(19) : float32[3D]", layout='fzyx')
             >>> f.layout
             (2, 1, 0)
+
+    Returns:
+        Sequence of fields created from the description
     """
     result = []
     if description:
@@ -1429,7 +1451,10 @@ def _parse_description(description):
                 data_type_str = data_type_str.lower().strip()
 
             if data_type_str:
-                dtype = create_type(data_type_str)
+                match data_type_str:
+                    case "dyn": dtype = DynamicType.NUMERIC_TYPE
+                    case "dynidx": dtype = DynamicType.INDEX_TYPE
+                    case _: dtype = create_type(data_type_str)
             else:
                 dtype = DynamicType.NUMERIC_TYPE
 
diff --git a/tests/frontend/test_field.py b/tests/frontend/test_field.py
index 6ac76f3c6..6d2942569 100644
--- a/tests/frontend/test_field.py
+++ b/tests/frontend/test_field.py
@@ -84,6 +84,12 @@ def test_field_description_parsing():
     assert f.index_shape == (1,)
     assert g.index_shape == (3,)
 
+    f = fields("f: dyn[3D]")
+    assert f.dtype == DynamicType.NUMERIC_TYPE
+
+    idx = fields("idx: dynidx[3D]")
+    assert idx.dtype == DynamicType.INDEX_TYPE
+
     h = fields("h: float32[3D]")
 
     assert h.index_shape == ()
-- 
GitLab