From 4857c7903d9e9742110756d5ab1b68239d5fc15c Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 30 Jan 2025 08:09:54 +0100
Subject: [PATCH] Adapt SFG to field typing changes in pystencils

---
 .../guide_generator_scripts/05/kernels.py     |  4 ++--
 src/pystencilssfg/lang/cpp/std_mdspan.py      | 10 +++++++++-
 src/pystencilssfg/lang/cpp/std_span.py        |  5 ++++-
 src/pystencilssfg/lang/cpp/std_vector.py      |  9 +++++++--
 src/pystencilssfg/lang/cpp/sycl_accessor.py   |  5 ++++-
 tests/lang/test_cpp_stl_classes.py            | 19 +++++++++++++++++++
 6 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/docs/source/usage/examples/guide_generator_scripts/05/kernels.py b/docs/source/usage/examples/guide_generator_scripts/05/kernels.py
index 27cebb8..9b04c5b 100644
--- a/docs/source/usage/examples/guide_generator_scripts/05/kernels.py
+++ b/docs/source/usage/examples/guide_generator_scripts/05/kernels.py
@@ -5,7 +5,7 @@ import sympy as sp
 
 with SourceFileGenerator() as sfg:
     #   Define a copy kernel
-    src, dst = ps.fields("src, dst: [1D]")
+    src, dst = ps.fields("src, dst: double[1D]")
     c = sp.Symbol("c")
 
     @ps.kernel
@@ -19,7 +19,7 @@ with SourceFileGenerator() as sfg:
     import pystencilssfg.lang.cpp.std as std
 
     sfg.include("<span>")
-    
+
     sfg.function("scale_kernel")(
         sfg.map_field(src, std.vector.from_field(src)),
         sfg.map_field(dst, std.span.from_field(dst)),
diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py
index 6830885..5e552e9 100644
--- a/src/pystencilssfg/lang/cpp/std_mdspan.py
+++ b/src/pystencilssfg/lang/cpp/std_mdspan.py
@@ -1,7 +1,7 @@
 from typing import cast
 from sympy import Symbol
 
-from pystencils import Field
+from pystencils import Field, DynamicType
 from pystencils.types import (
     PsType,
     PsUnsignedIntegerType,
@@ -115,10 +115,15 @@ class StdMdspan(SrcField):
         )
         super().__init__(dtype)
 
+        self._element_type = T
         self._extents_type = extents_str
         self._layout_type = layout_policy
         self._dim = len(extents)
 
+    @property
+    def element_type(self) -> PsType:
+        return self._element_type
+
     @property
     def extents_type(self) -> str:
         return self._extents_type
@@ -166,6 +171,9 @@ class StdMdspan(SrcField):
         const: bool = False,
     ):
         """Creates a `std::mdspan` instance for a given pystencils field."""
+        if isinstance(field.dtype, DynamicType):
+            raise ValueError("Cannot map dynamically typed field to std::mdspan")
+
         extents: list[str | int] = []
 
         for s in field.spatial_shape:
diff --git a/src/pystencilssfg/lang/cpp/std_span.py b/src/pystencilssfg/lang/cpp/std_span.py
index f161f48..ea4b520 100644
--- a/src/pystencilssfg/lang/cpp/std_span.py
+++ b/src/pystencilssfg/lang/cpp/std_span.py
@@ -1,4 +1,4 @@
-from pystencils.field import Field
+from pystencils import Field, DynamicType
 from pystencils.types import UserTypeSpec, create_type, PsType
 
 from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype
@@ -44,6 +44,9 @@ class StdSpan(SrcField):
             raise ValueError(
                 "Only one-dimensional fields with trivial index dimensions can be mapped onto `std::span`"
             )
+        if isinstance(field.dtype, DynamicType):
+            raise ValueError("Cannot map dynamically typed field to std::span")
+
         return StdSpan(field.dtype, ref=ref, const=const).var(field.name)
 
 
diff --git a/src/pystencilssfg/lang/cpp/std_vector.py b/src/pystencilssfg/lang/cpp/std_vector.py
index 7e9291e..7356f94 100644
--- a/src/pystencilssfg/lang/cpp/std_vector.py
+++ b/src/pystencilssfg/lang/cpp/std_vector.py
@@ -1,4 +1,4 @@
-from pystencils.field import Field
+from pystencils import Field, DynamicType
 from pystencils.types import UserTypeSpec, create_type, PsType
 
 from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype
@@ -59,7 +59,12 @@ class StdVector(SrcVector, SrcField):
                 f"Cannot create std::vector from more-than-one-dimensional field {field}."
             )
 
-        return StdVector(field.dtype, unsafe=False, ref=ref, const=const).var(field.name)
+        if isinstance(field.dtype, DynamicType):
+            raise ValueError("Cannot map dynamically typed field to std::vector")
+
+        return StdVector(field.dtype, unsafe=False, ref=ref, const=const).var(
+            field.name
+        )
 
 
 def std_vector_ref(field: Field):
diff --git a/src/pystencilssfg/lang/cpp/sycl_accessor.py b/src/pystencilssfg/lang/cpp/sycl_accessor.py
index 4bcad56..0052302 100644
--- a/src/pystencilssfg/lang/cpp/sycl_accessor.py
+++ b/src/pystencilssfg/lang/cpp/sycl_accessor.py
@@ -1,6 +1,6 @@
 from ...lang import SrcField, IFieldExtraction
 
-from pystencils import Field
+from pystencils import Field, DynamicType
 from pystencils.types import UserTypeSpec, create_type
 
 from ...lang import AugExpr, cpptype
@@ -75,6 +75,9 @@ class SyclAccessor(SrcField):
     def from_field(field: Field, ref: bool = True):
         """Creates a `sycl::accessor &` for a given pystencils field."""
 
+        if isinstance(field.dtype, DynamicType):
+            raise ValueError("Cannot map dynamically typed field to sycl::accessor")
+
         return SyclAccessor(
             field.dtype,
             field.spatial_dimensions + field.index_dimensions,
diff --git a/tests/lang/test_cpp_stl_classes.py b/tests/lang/test_cpp_stl_classes.py
index e400114..47966e2 100644
--- a/tests/lang/test_cpp_stl_classes.py
+++ b/tests/lang/test_cpp_stl_classes.py
@@ -31,6 +31,17 @@ def test_stl_containers():
     assert includes(expr) == {HeaderFile.parse("<span>")}
 
 
+def test_mdspan_from_field():
+    f = ps.fields("f: float32[3D]")
+    f_mdspan = std.mdspan.from_field(f)
+
+    assert f_mdspan.element_type == ps.create_type("float32")
+
+    f = ps.fields("f: dyn[3D]")
+    with pytest.raises(ValueError):
+        f_mdspan = std.mdspan.from_field(f)
+
+
 def test_vector_from_field():
     f = ps.fields("f: float32[1D]")
     f_vec = std.vector.from_field(f)
@@ -52,6 +63,10 @@ def test_vector_from_field():
     with pytest.raises(ValueError):
         std.vector.from_field(f)
 
+    f = ps.fields("f(1): dyn[1D]")
+    with pytest.raises(ValueError):
+        std.vector.from_field(f)
+
 
 def test_span_from_field():
     f = ps.fields("f: float32[1D]")
@@ -73,3 +88,7 @@ def test_span_from_field():
     f = ps.fields("f(1): float32[2D]")
     with pytest.raises(ValueError):
         std.span.from_field(f)
+
+    f = ps.fields("f(1): dyn[1D]")
+    with pytest.raises(ValueError):
+        std.span.from_field(f)
-- 
GitLab