Skip to content
Snippets Groups Projects
Commit 5bd0cafe authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix mdspan layout policies

parent af3aaea8
1 merge request!5Extend mdspan interface and fix mdspan memory layout mapping
......@@ -33,19 +33,22 @@ class StdMdspan(SrcField):
dynamic_extent = "std::dynamic_extent"
_namespace = "std"
_template = cpptype("std::mdspan< {T}, {extents} >", "<mdspan>")
_template = cpptype("std::mdspan< {T}, {extents}, {layout_policy} >", "<mdspan>")
@classmethod
def configure(cls, namespace: str = "std", header: str | HeaderFile = "<mdspan>"):
"""Configure the namespace and header `mdspan` is defined in."""
cls._namespace = namespace
cls._template = cpptype(f"{namespace}::mdspan< {{T}}, {{extents}} >", header)
cls._template = cpptype(
f"{namespace}::mdspan< {{T}}, {{extents}}, {{layout_policy}} >", header
)
def __init__(
self,
T: UserTypeSpec,
extents: tuple[int | str, ...],
extents_type: PsType = PsUnsignedIntegerType(64),
layout_policy: str | None = None,
ref: bool = False,
const: bool = False,
):
......@@ -54,7 +57,13 @@ class StdMdspan(SrcField):
extents_type_str = extents_type.c_string()
extents_str = f"{self._namespace}::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"
dtype = self._template(T=T, extents=extents_str, const=const)
layout_policy = (
f"{self._namespace}::layout_right"
if layout_policy is None
else layout_policy
)
dtype = self._template(T=T, extents=extents_str, layout_policy=layout_policy, const=const)
if ref:
dtype = Ref(dtype)
......@@ -84,8 +93,9 @@ class StdMdspan(SrcField):
return Extraction()
@staticmethod
@classmethod
def from_field(
cls,
field: Field,
extents_type: PsType = PsUnsignedIntegerType(64),
ref: bool = False,
......@@ -94,10 +104,16 @@ class StdMdspan(SrcField):
"""Creates a `std::mdspan` instance for a given pystencils field."""
from pystencils.field import layout_string_to_tuple
if field.layout != layout_string_to_tuple("soa", field.spatial_dimensions):
raise NotImplementedError(
"mdspan mapping is currently only available for structure-of-arrays fields"
)
layout_policy: str
if field.layout == layout_string_to_tuple("fzyx", field.spatial_dimensions):
# f is the rightmost extent, which is slowest with `layout_left`
layout_policy = f"{cls._namespace}::layout_left"
elif field.layout == layout_string_to_tuple("zyxf", field.spatial_dimensions):
# f, as the rightmost extent, is the fastest
layout_policy = f"{cls._namespace}::layout_right"
else:
layout_policy = f"{cls._namespace}::layout_stride"
extents: list[str | int] = []
......@@ -110,7 +126,12 @@ class StdMdspan(SrcField):
extents.append(StdMdspan.dynamic_extent if isinstance(s, Symbol) else s)
return StdMdspan(
field.dtype, tuple(extents), extents_type=extents_type, ref=ref, const=const
field.dtype,
tuple(extents),
extents_type=extents_type,
layout_policy=layout_policy,
ref=ref,
const=const,
).var(field.name)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment