Skip to content
Snippets Groups Projects
Commit 7d24e077 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

update layout_policy in mdspan mirror. Add test case for fixed-shape mdspans...

update layout_policy in mdspan mirror. Add test case for fixed-shape mdspans and their memory layouts.
parent 49b4c18d
No related branches found
No related tags found
1 merge request!5Extend mdspan interface and fix mdspan memory layout mapping
Pipeline #70955 passed
...@@ -11,7 +11,7 @@ from pystencils.types import ( ...@@ -11,7 +11,7 @@ from pystencils.types import (
from pystencilssfg.lang.expressions import AugExpr from pystencilssfg.lang.expressions import AugExpr
from ...lang import SrcField, IFieldExtraction, cpptype, Ref, HeaderFile from ...lang import SrcField, IFieldExtraction, cpptype, Ref, HeaderFile, ExprLike
class StdMdspan(SrcField): class StdMdspan(SrcField):
...@@ -42,9 +42,13 @@ class StdMdspan(SrcField): ...@@ -42,9 +42,13 @@ class StdMdspan(SrcField):
**Creation from pystencils fields** **Creation from pystencils fields**
Using `from_field`, ``mdspan`` objects can be created directly from `Field <pystencils.Field>` instances. Using `from_field`, ``mdspan`` objects can be created directly from `Field <pystencils.Field>` instances.
The `extents`_ and `layout_policy`_ of the ``mdspan`` type will be inferred from the field. The `extents`_ of the ``mdspan`` type will be inferred from the field;
Each fixed entry in the field's shape will become a fixed entry of the ``mdspan``'s extents. each fixed entry in the field's shape will become a fixed entry of the ``mdspan``'s extents.
The field's memory layout will be mapped onto a predefined ``LayoutPolicy`` according to this table:
The ``mdspan``'s `layout_policy`_ defaults to `std::layout_stride`_,
which might not be the optimal choice depending on the memory layout of your fields.
You may therefore override this by specifying the name of the desired layout policy.
To map pystencils field layout identifiers to layout policies, consult the following table:
+------------------------+--------------------------+ +------------------------+--------------------------+
| pystencils Layout Name | ``mdspan`` Layout Policy | | pystencils Layout Name | ``mdspan`` Layout Policy |
...@@ -61,8 +65,8 @@ class StdMdspan(SrcField): ...@@ -61,8 +65,8 @@ class StdMdspan(SrcField):
| ``"aos"`` | | | ``"aos"`` | |
+------------------------+--------------------------+ +------------------------+--------------------------+
The structure-of-arrays (or ZYXF) layout has no equivalent layout policy in the C++ standard, The array-of-structures (``"aos"``, ``"zyxf"``) layout has no equivalent layout policy in the C++ standard,
so it can only be mapped onto ``layout_stride``, which models user-defined strides. so it can only be mapped onto ``layout_stride``.
.. _extents: https://en.cppreference.com/w/cpp/container/mdspan/extents .. _extents: https://en.cppreference.com/w/cpp/container/mdspan/extents
.. _layout_policy: https://en.cppreference.com/w/cpp/named_req/LayoutMappingPolicy .. _layout_policy: https://en.cppreference.com/w/cpp/named_req/LayoutMappingPolicy
...@@ -91,49 +95,68 @@ class StdMdspan(SrcField): ...@@ -91,49 +95,68 @@ class StdMdspan(SrcField):
self, self,
T: UserTypeSpec, T: UserTypeSpec,
extents: tuple[int | str, ...], extents: tuple[int | str, ...],
extents_type: UserTypeSpec = PsUnsignedIntegerType(64), index_type: UserTypeSpec = PsUnsignedIntegerType(64),
layout_policy: str | None = None, layout_policy: str | None = None,
ref: bool = False, ref: bool = False,
const: bool = False, const: bool = False,
): ):
T = create_type(T) T = create_type(T)
extents_type_str = create_type(extents_type).c_string() extents_type_str = create_type(index_type).c_string()
extents_str = f"{self._namespace}::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >" extents_str = f"{self._namespace}::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"
layout_policy = ( if layout_policy is None:
f"{self._namespace}::layout_right" layout_policy = f"{self._namespace}::layout_stride"
if layout_policy is None elif layout_policy in ("layout_left", "layout_right", "layout_stride"):
else layout_policy layout_policy = f"{self._namespace}::{layout_policy}"
)
dtype = self._template(T=T, extents=extents_str, layout_policy=layout_policy, const=const) dtype = self._template(
T=T, extents=extents_str, layout_policy=layout_policy, const=const
)
if ref: if ref:
dtype = Ref(dtype) dtype = Ref(dtype)
super().__init__(dtype) super().__init__(dtype)
self._extents = extents self._extents_type = extents_str
self._layout_type = layout_policy
self._dim = len(extents) self._dim = len(extents)
@property
def extents_type(self) -> str:
return self._extents_type
@property
def layout_type(self) -> str:
return self._layout_type
def extent(self, r: int | ExprLike) -> AugExpr:
return AugExpr.format("{}.extent({})", self, r)
def stride(self, r: int | ExprLike) -> AugExpr:
return AugExpr.format("{}.stride({})", self, r)
def data_handle(self) -> AugExpr:
return AugExpr.format("{}.data_handle()", self)
def get_extraction(self) -> IFieldExtraction: def get_extraction(self) -> IFieldExtraction:
mdspan = self mdspan = self
class Extraction(IFieldExtraction): class Extraction(IFieldExtraction):
def ptr(self) -> AugExpr: def ptr(self) -> AugExpr:
return AugExpr.format("{}.data_handle()", mdspan) return mdspan.data_handle()
def size(self, coordinate: int) -> AugExpr | None: def size(self, coordinate: int) -> AugExpr | None:
if coordinate > mdspan._dim: if coordinate > mdspan._dim:
return None return None
else: else:
return AugExpr.format("{}.extents().extent({})", mdspan, coordinate) return mdspan.extent(coordinate)
def stride(self, coordinate: int) -> AugExpr | None: def stride(self, coordinate: int) -> AugExpr | None:
if coordinate > mdspan._dim: if coordinate > mdspan._dim:
return None return None
else: else:
return AugExpr.format("{}.stride({})", mdspan, coordinate) return mdspan.stride(coordinate)
return Extraction() return Extraction()
...@@ -147,18 +170,6 @@ class StdMdspan(SrcField): ...@@ -147,18 +170,6 @@ class StdMdspan(SrcField):
const: bool = False, const: bool = False,
): ):
"""Creates a `std::mdspan` instance for a given pystencils field.""" """Creates a `std::mdspan` instance for a given pystencils field."""
from pystencils.field import layout_string_to_tuple
if layout_policy is None:
if field.layout == layout_string_to_tuple("fzyx", field.spatial_dimensions):
# f is the rightmost extent, which is slowest with `layout_left`
layout_policy = f"{cls._namespace}::layout_left"
elif field.layout == layout_string_to_tuple("c", field.spatial_dimensions):
# f, as the rightmost extent, is the fastest
layout_policy = f"{cls._namespace}::layout_right"
else:
layout_policy = f"{cls._namespace}::layout_stride"
extents: list[str | int] = [] extents: list[str | int] = []
for s in field.spatial_shape: for s in field.spatial_shape:
...@@ -172,7 +183,7 @@ class StdMdspan(SrcField): ...@@ -172,7 +183,7 @@ class StdMdspan(SrcField):
return StdMdspan( return StdMdspan(
field.dtype, field.dtype,
tuple(extents), tuple(extents),
extents_type=extents_type, index_type=extents_type,
layout_policy=layout_policy, layout_policy=layout_policy,
ref=ref, ref=ref,
const=const, const=const,
......
...@@ -54,6 +54,10 @@ ScaleKernel: ...@@ -54,6 +54,10 @@ ScaleKernel:
JacobiMdspan: JacobiMdspan:
StlContainers1D: StlContainers1D:
# std::mdspan
MdSpanFixedShapeLayouts:
# SYCL # SYCL
SyclKernels: SyclKernels:
......
#include "MdSpanLayouts.hpp"
#include <concepts>
#include <experimental/mdspan>
namespace stdex = std::experimental;
static_assert( std::is_same_v< gen::field_soa::layout_type, stdex::layout_left > );
static_assert( std::is_same_v< gen::field_aos::layout_type, stdex::layout_stride > );
static_assert( std::is_same_v< gen::field_c::layout_type, stdex::layout_right > );
static_assert( gen::field_soa::static_extent(0) == 17 );
static_assert( gen::field_soa::static_extent(1) == 19 );
static_assert( gen::field_soa::static_extent(2) == 32 );
static_assert( gen::field_soa::static_extent(3) == 9 );
int main(void) {
gen::field_soa f_soa { nullptr };
gen::checkLayoutSoa(f_soa);
gen::field_aos::extents_type f_aos_extents { };
std::array< uint64_t, 4 > strides_aos {
/* stride(x) */ f_aos_extents.extent(3),
/* stride(y) */ f_aos_extents.extent(3) * f_aos_extents.extent(0),
/* stride(z) */ f_aos_extents.extent(3) * f_aos_extents.extent(0) * f_aos_extents.extent(1),
/* stride(f) */ 1
};
gen::field_aos::mapping_type f_aos_mapping { f_aos_extents, strides_aos };
gen::field_aos f_aos { nullptr, f_aos_mapping };
gen::checkLayoutAos(f_aos);
gen::field_c f_c { nullptr };
gen::checkLayoutC(f_c);
}
import pystencils as ps
from pystencilssfg import SourceFileGenerator
from pystencilssfg.lang.cpp import std
from pystencilssfg.lang import strip_ptr_ref
std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>")
with SourceFileGenerator() as sfg:
sfg.namespace("gen")
sfg.include("<cassert>")
def check_layout(field: ps.Field, mdspan: std.mdspan):
seq = []
for d in range(field.spatial_dimensions + field.index_dimensions):
seq += [
sfg.expr(
'assert({} == {} && "Shape mismatch at coordinate {}");',
mdspan.extent(d),
field.shape[d],
d,
),
sfg.expr(
'assert({} == {} && "Stride mismatch at coordinate {}");',
mdspan.stride(d),
field.strides[d],
d,
),
]
return seq
f_soa = ps.fields("f_soa(9): double[17, 19, 32]", layout="soa")
f_soa_mdspan = std.mdspan.from_field(f_soa, layout_policy="layout_left", ref=True)
sfg.code(f"using field_soa = {strip_ptr_ref(f_soa_mdspan.dtype)};")
sfg.function("checkLayoutSoa")(
*check_layout(f_soa, f_soa_mdspan)
)
f_aos = ps.fields("f_aos(9): double[17, 19, 32]", layout="aos")
f_aos_mdspan = std.mdspan.from_field(f_aos, ref=True)
sfg.code(f"using field_aos = {strip_ptr_ref(f_aos_mdspan.dtype)};")
sfg.function("checkLayoutAos")(
*check_layout(f_aos, f_aos_mdspan)
)
f_c = ps.fields("f_c(9): double[17, 19, 32]", layout="c")
f_c_mdspan = std.mdspan.from_field(f_c, layout_policy="layout_right", ref=True)
sfg.code(f"using field_c = {strip_ptr_ref(f_c_mdspan.dtype)};")
sfg.function("checkLayoutC")(
*check_layout(f_c, f_c_mdspan)
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment