From 5bd0cafeba2cf181f7cc85e256584ed5eb9cd4ff Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 11 Dec 2024 15:40:34 +0100 Subject: [PATCH] fix mdspan layout policies --- src/pystencilssfg/lang/cpp/std_mdspan.py | 39 ++++++++++++++++++------ 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py index 2ede108..7c3fd35 100644 --- a/src/pystencilssfg/lang/cpp/std_mdspan.py +++ b/src/pystencilssfg/lang/cpp/std_mdspan.py @@ -33,19 +33,22 @@ class StdMdspan(SrcField): dynamic_extent = "std::dynamic_extent" _namespace = "std" - _template = cpptype("std::mdspan< {T}, {extents} >", "<mdspan>") + _template = cpptype("std::mdspan< {T}, {extents}, {layout_policy} >", "<mdspan>") @classmethod def configure(cls, namespace: str = "std", header: str | HeaderFile = "<mdspan>"): """Configure the namespace and header `mdspan` is defined in.""" cls._namespace = namespace - cls._template = cpptype(f"{namespace}::mdspan< {{T}}, {{extents}} >", header) + cls._template = cpptype( + f"{namespace}::mdspan< {{T}}, {{extents}}, {{layout_policy}} >", header + ) def __init__( self, T: UserTypeSpec, extents: tuple[int | str, ...], extents_type: PsType = PsUnsignedIntegerType(64), + layout_policy: str | None = None, ref: bool = False, const: bool = False, ): @@ -54,7 +57,13 @@ class StdMdspan(SrcField): extents_type_str = extents_type.c_string() extents_str = f"{self._namespace}::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >" - dtype = self._template(T=T, extents=extents_str, const=const) + layout_policy = ( + f"{self._namespace}::layout_right" + if layout_policy is None + else layout_policy + ) + + dtype = self._template(T=T, extents=extents_str, layout_policy=layout_policy, const=const) if ref: dtype = Ref(dtype) @@ -84,8 +93,9 @@ class StdMdspan(SrcField): return Extraction() - @staticmethod + @classmethod def from_field( + cls, field: Field, extents_type: PsType = PsUnsignedIntegerType(64), ref: bool = False, @@ -94,10 +104,16 @@ class StdMdspan(SrcField): """Creates a `std::mdspan` instance for a given pystencils field.""" from pystencils.field import layout_string_to_tuple - if field.layout != layout_string_to_tuple("soa", field.spatial_dimensions): - raise NotImplementedError( - "mdspan mapping is currently only available for structure-of-arrays fields" - ) + layout_policy: str + + if field.layout == layout_string_to_tuple("fzyx", field.spatial_dimensions): + # f is the rightmost extent, which is slowest with `layout_left` + layout_policy = f"{cls._namespace}::layout_left" + elif field.layout == layout_string_to_tuple("zyxf", field.spatial_dimensions): + # f, as the rightmost extent, is the fastest + layout_policy = f"{cls._namespace}::layout_right" + else: + layout_policy = f"{cls._namespace}::layout_stride" extents: list[str | int] = [] @@ -110,7 +126,12 @@ class StdMdspan(SrcField): extents.append(StdMdspan.dynamic_extent if isinstance(s, Symbol) else s) return StdMdspan( - field.dtype, tuple(extents), extents_type=extents_type, ref=ref, const=const + field.dtype, + tuple(extents), + extents_type=extents_type, + layout_policy=layout_policy, + ref=ref, + const=const, ).var(field.name) -- GitLab