From 5bd0cafeba2cf181f7cc85e256584ed5eb9cd4ff Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 11 Dec 2024 15:40:34 +0100
Subject: [PATCH] fix mdspan layout policies

---
 src/pystencilssfg/lang/cpp/std_mdspan.py | 39 ++++++++++++++++++------
 1 file changed, 30 insertions(+), 9 deletions(-)

diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py
index 2ede108..7c3fd35 100644
--- a/src/pystencilssfg/lang/cpp/std_mdspan.py
+++ b/src/pystencilssfg/lang/cpp/std_mdspan.py
@@ -33,19 +33,22 @@ class StdMdspan(SrcField):
     dynamic_extent = "std::dynamic_extent"
 
     _namespace = "std"
-    _template = cpptype("std::mdspan< {T}, {extents} >", "<mdspan>")
+    _template = cpptype("std::mdspan< {T}, {extents}, {layout_policy} >", "<mdspan>")
 
     @classmethod
     def configure(cls, namespace: str = "std", header: str | HeaderFile = "<mdspan>"):
         """Configure the namespace and header `mdspan` is defined in."""
         cls._namespace = namespace
-        cls._template = cpptype(f"{namespace}::mdspan< {{T}}, {{extents}} >", header)
+        cls._template = cpptype(
+            f"{namespace}::mdspan< {{T}}, {{extents}}, {{layout_policy}} >", header
+        )
 
     def __init__(
         self,
         T: UserTypeSpec,
         extents: tuple[int | str, ...],
         extents_type: PsType = PsUnsignedIntegerType(64),
+        layout_policy: str | None = None,
         ref: bool = False,
         const: bool = False,
     ):
@@ -54,7 +57,13 @@ class StdMdspan(SrcField):
         extents_type_str = extents_type.c_string()
         extents_str = f"{self._namespace}::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"
 
-        dtype = self._template(T=T, extents=extents_str, const=const)
+        layout_policy = (
+            f"{self._namespace}::layout_right"
+            if layout_policy is None
+            else layout_policy
+        )
+
+        dtype = self._template(T=T, extents=extents_str, layout_policy=layout_policy, const=const)
 
         if ref:
             dtype = Ref(dtype)
@@ -84,8 +93,9 @@ class StdMdspan(SrcField):
 
         return Extraction()
 
-    @staticmethod
+    @classmethod
     def from_field(
+        cls,
         field: Field,
         extents_type: PsType = PsUnsignedIntegerType(64),
         ref: bool = False,
@@ -94,10 +104,16 @@ class StdMdspan(SrcField):
         """Creates a `std::mdspan` instance for a given pystencils field."""
         from pystencils.field import layout_string_to_tuple
 
-        if field.layout != layout_string_to_tuple("soa", field.spatial_dimensions):
-            raise NotImplementedError(
-                "mdspan mapping is currently only available for structure-of-arrays fields"
-            )
+        layout_policy: str
+
+        if field.layout == layout_string_to_tuple("fzyx", field.spatial_dimensions):
+            #   f is the rightmost extent, which is slowest with `layout_left`
+            layout_policy = f"{cls._namespace}::layout_left"
+        elif field.layout == layout_string_to_tuple("zyxf", field.spatial_dimensions):
+            #   f, as the rightmost extent, is the fastest
+            layout_policy = f"{cls._namespace}::layout_right"
+        else:
+            layout_policy = f"{cls._namespace}::layout_stride"
 
         extents: list[str | int] = []
 
@@ -110,7 +126,12 @@ class StdMdspan(SrcField):
             extents.append(StdMdspan.dynamic_extent if isinstance(s, Symbol) else s)
 
         return StdMdspan(
-            field.dtype, tuple(extents), extents_type=extents_type, ref=ref, const=const
+            field.dtype,
+            tuple(extents),
+            extents_type=extents_type,
+            layout_policy=layout_policy,
+            ref=ref,
+            const=const,
         ).var(field.name)
 
 
-- 
GitLab