Skip to content
Snippets Groups Projects
Commit 5bd0cafe authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix mdspan layout policies

parent af3aaea8
No related branches found
No related tags found
1 merge request!5Extend mdspan interface and fix mdspan memory layout mapping
......@@ -33,19 +33,22 @@ class StdMdspan(SrcField):
dynamic_extent = "std::dynamic_extent"
_namespace = "std"
_template = cpptype("std::mdspan< {T}, {extents} >", "<mdspan>")
_template = cpptype("std::mdspan< {T}, {extents}, {layout_policy} >", "<mdspan>")
@classmethod
def configure(cls, namespace: str = "std", header: str | HeaderFile = "<mdspan>"):
"""Configure the namespace and header `mdspan` is defined in."""
cls._namespace = namespace
cls._template = cpptype(f"{namespace}::mdspan< {{T}}, {{extents}} >", header)
cls._template = cpptype(
f"{namespace}::mdspan< {{T}}, {{extents}}, {{layout_policy}} >", header
)
def __init__(
self,
T: UserTypeSpec,
extents: tuple[int | str, ...],
extents_type: PsType = PsUnsignedIntegerType(64),
layout_policy: str | None = None,
ref: bool = False,
const: bool = False,
):
......@@ -54,7 +57,13 @@ class StdMdspan(SrcField):
extents_type_str = extents_type.c_string()
extents_str = f"{self._namespace}::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"
dtype = self._template(T=T, extents=extents_str, const=const)
layout_policy = (
f"{self._namespace}::layout_right"
if layout_policy is None
else layout_policy
)
dtype = self._template(T=T, extents=extents_str, layout_policy=layout_policy, const=const)
if ref:
dtype = Ref(dtype)
......@@ -84,8 +93,9 @@ class StdMdspan(SrcField):
return Extraction()
@staticmethod
@classmethod
def from_field(
cls,
field: Field,
extents_type: PsType = PsUnsignedIntegerType(64),
ref: bool = False,
......@@ -94,10 +104,16 @@ class StdMdspan(SrcField):
"""Creates a `std::mdspan` instance for a given pystencils field."""
from pystencils.field import layout_string_to_tuple
if field.layout != layout_string_to_tuple("soa", field.spatial_dimensions):
raise NotImplementedError(
"mdspan mapping is currently only available for structure-of-arrays fields"
)
layout_policy: str
if field.layout == layout_string_to_tuple("fzyx", field.spatial_dimensions):
# f is the rightmost extent, which is slowest with `layout_left`
layout_policy = f"{cls._namespace}::layout_left"
elif field.layout == layout_string_to_tuple("zyxf", field.spatial_dimensions):
# f, as the rightmost extent, is the fastest
layout_policy = f"{cls._namespace}::layout_right"
else:
layout_policy = f"{cls._namespace}::layout_stride"
extents: list[str | int] = []
......@@ -110,7 +126,12 @@ class StdMdspan(SrcField):
extents.append(StdMdspan.dynamic_extent if isinstance(s, Symbol) else s)
return StdMdspan(
field.dtype, tuple(extents), extents_type=extents_type, ref=ref, const=const
field.dtype,
tuple(extents),
extents_type=extents_type,
layout_policy=layout_policy,
ref=ref,
const=const,
).var(field.name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment