diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py index 4f7616011339d6640412305493f064972c3f081f..34a444443fbee7da44e4fa62d34a7e434d4f25ee 100644 --- a/src/pystencilssfg/lang/cpp/std_mdspan.py +++ b/src/pystencilssfg/lang/cpp/std_mdspan.py @@ -11,7 +11,7 @@ from pystencils.types import ( from pystencilssfg.lang.expressions import AugExpr -from ...lang import SrcField, IFieldExtraction, cpptype, Ref, HeaderFile +from ...lang import SrcField, IFieldExtraction, cpptype, Ref, HeaderFile, ExprLike class StdMdspan(SrcField): @@ -42,9 +42,13 @@ class StdMdspan(SrcField): **Creation from pystencils fields** Using `from_field`, ``mdspan`` objects can be created directly from `Field <pystencils.Field>` instances. - The `extents`_ and `layout_policy`_ of the ``mdspan`` type will be inferred from the field. - Each fixed entry in the field's shape will become a fixed entry of the ``mdspan``'s extents. - The field's memory layout will be mapped onto a predefined ``LayoutPolicy`` according to this table: + The `extents`_ of the ``mdspan`` type will be inferred from the field; + each fixed entry in the field's shape will become a fixed entry of the ``mdspan``'s extents. + + The ``mdspan``'s `layout_policy`_ defaults to `std::layout_stride`_, + which might not be the optimal choice depending on the memory layout of your fields. + You may therefore override this by specifying the name of the desired layout policy. + To map pystencils field layout identifiers to layout policies, consult the following table: +------------------------+--------------------------+ | pystencils Layout Name | ``mdspan`` Layout Policy | @@ -61,8 +65,8 @@ class StdMdspan(SrcField): | ``"aos"`` | | +------------------------+--------------------------+ - The structure-of-arrays (or ZYXF) layout has no equivalent layout policy in the C++ standard, - so it can only be mapped onto ``layout_stride``, which models user-defined strides. + The array-of-structures (``"aos"``, ``"zyxf"``) layout has no equivalent layout policy in the C++ standard, + so it can only be mapped onto ``layout_stride``. .. _extents: https://en.cppreference.com/w/cpp/container/mdspan/extents .. _layout_policy: https://en.cppreference.com/w/cpp/named_req/LayoutMappingPolicy @@ -91,49 +95,68 @@ class StdMdspan(SrcField): self, T: UserTypeSpec, extents: tuple[int | str, ...], - extents_type: UserTypeSpec = PsUnsignedIntegerType(64), + index_type: UserTypeSpec = PsUnsignedIntegerType(64), layout_policy: str | None = None, ref: bool = False, const: bool = False, ): T = create_type(T) - extents_type_str = create_type(extents_type).c_string() + extents_type_str = create_type(index_type).c_string() extents_str = f"{self._namespace}::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >" - layout_policy = ( - f"{self._namespace}::layout_right" - if layout_policy is None - else layout_policy - ) + if layout_policy is None: + layout_policy = f"{self._namespace}::layout_stride" + elif layout_policy in ("layout_left", "layout_right", "layout_stride"): + layout_policy = f"{self._namespace}::{layout_policy}" - dtype = self._template(T=T, extents=extents_str, layout_policy=layout_policy, const=const) + dtype = self._template( + T=T, extents=extents_str, layout_policy=layout_policy, const=const + ) if ref: dtype = Ref(dtype) super().__init__(dtype) - self._extents = extents + self._extents_type = extents_str + self._layout_type = layout_policy self._dim = len(extents) + @property + def extents_type(self) -> str: + return self._extents_type + + @property + def layout_type(self) -> str: + return self._layout_type + + def extent(self, r: int | ExprLike) -> AugExpr: + return AugExpr.format("{}.extent({})", self, r) + + def stride(self, r: int | ExprLike) -> AugExpr: + return AugExpr.format("{}.stride({})", self, r) + + def data_handle(self) -> AugExpr: + return AugExpr.format("{}.data_handle()", self) + def get_extraction(self) -> IFieldExtraction: mdspan = self class Extraction(IFieldExtraction): def ptr(self) -> AugExpr: - return AugExpr.format("{}.data_handle()", mdspan) + return mdspan.data_handle() def size(self, coordinate: int) -> AugExpr | None: if coordinate > mdspan._dim: return None else: - return AugExpr.format("{}.extents().extent({})", mdspan, coordinate) + return mdspan.extent(coordinate) def stride(self, coordinate: int) -> AugExpr | None: if coordinate > mdspan._dim: return None else: - return AugExpr.format("{}.stride({})", mdspan, coordinate) + return mdspan.stride(coordinate) return Extraction() @@ -147,18 +170,6 @@ class StdMdspan(SrcField): const: bool = False, ): """Creates a `std::mdspan` instance for a given pystencils field.""" - from pystencils.field import layout_string_to_tuple - - if layout_policy is None: - if field.layout == layout_string_to_tuple("fzyx", field.spatial_dimensions): - # f is the rightmost extent, which is slowest with `layout_left` - layout_policy = f"{cls._namespace}::layout_left" - elif field.layout == layout_string_to_tuple("c", field.spatial_dimensions): - # f, as the rightmost extent, is the fastest - layout_policy = f"{cls._namespace}::layout_right" - else: - layout_policy = f"{cls._namespace}::layout_stride" - extents: list[str | int] = [] for s in field.spatial_shape: @@ -172,7 +183,7 @@ class StdMdspan(SrcField): return StdMdspan( field.dtype, tuple(extents), - extents_type=extents_type, + index_type=extents_type, layout_policy=layout_policy, ref=ref, const=const, diff --git a/tests/generator_scripts/index.yaml b/tests/generator_scripts/index.yaml index 91bb77b99bd6e1dc8b9f7acfea936a8d8f97fa60..5cf5a9b542335d718b6a30ee423882c413deed01 100644 --- a/tests/generator_scripts/index.yaml +++ b/tests/generator_scripts/index.yaml @@ -54,6 +54,10 @@ ScaleKernel: JacobiMdspan: StlContainers1D: +# std::mdspan + +MdSpanFixedShapeLayouts: + # SYCL SyclKernels: diff --git a/tests/generator_scripts/source/MdSpanFixedShape.harness.cpp b/tests/generator_scripts/source/MdSpanFixedShape.harness.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc979c83a9b86f8c00ea0d78220063f7036eae58 --- /dev/null +++ b/tests/generator_scripts/source/MdSpanFixedShape.harness.cpp @@ -0,0 +1,35 @@ +#include "MdSpanLayouts.hpp" + +#include <concepts> +#include <experimental/mdspan> + +namespace stdex = std::experimental; + +static_assert( std::is_same_v< gen::field_soa::layout_type, stdex::layout_left > ); +static_assert( std::is_same_v< gen::field_aos::layout_type, stdex::layout_stride > ); +static_assert( std::is_same_v< gen::field_c::layout_type, stdex::layout_right > ); + +static_assert( gen::field_soa::static_extent(0) == 17 ); +static_assert( gen::field_soa::static_extent(1) == 19 ); +static_assert( gen::field_soa::static_extent(2) == 32 ); +static_assert( gen::field_soa::static_extent(3) == 9 ); + +int main(void) { + gen::field_soa f_soa { nullptr }; + gen::checkLayoutSoa(f_soa); + + gen::field_aos::extents_type f_aos_extents { }; + std::array< uint64_t, 4 > strides_aos { + /* stride(x) */ f_aos_extents.extent(3), + /* stride(y) */ f_aos_extents.extent(3) * f_aos_extents.extent(0), + /* stride(z) */ f_aos_extents.extent(3) * f_aos_extents.extent(0) * f_aos_extents.extent(1), + /* stride(f) */ 1 + }; + + gen::field_aos::mapping_type f_aos_mapping { f_aos_extents, strides_aos }; + gen::field_aos f_aos { nullptr, f_aos_mapping }; + gen::checkLayoutAos(f_aos); + + gen::field_c f_c { nullptr }; + gen::checkLayoutC(f_c); +} diff --git a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py new file mode 100644 index 0000000000000000000000000000000000000000..c89fe2455e3ff117596bdd63d538d56d2afcc3b5 --- /dev/null +++ b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py @@ -0,0 +1,55 @@ +import pystencils as ps +from pystencilssfg import SourceFileGenerator +from pystencilssfg.lang.cpp import std +from pystencilssfg.lang import strip_ptr_ref + +std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") + +with SourceFileGenerator() as sfg: + sfg.namespace("gen") + sfg.include("<cassert>") + + def check_layout(field: ps.Field, mdspan: std.mdspan): + seq = [] + + for d in range(field.spatial_dimensions + field.index_dimensions): + seq += [ + sfg.expr( + 'assert({} == {} && "Shape mismatch at coordinate {}");', + mdspan.extent(d), + field.shape[d], + d, + ), + sfg.expr( + 'assert({} == {} && "Stride mismatch at coordinate {}");', + mdspan.stride(d), + field.strides[d], + d, + ), + ] + + return seq + + f_soa = ps.fields("f_soa(9): double[17, 19, 32]", layout="soa") + f_soa_mdspan = std.mdspan.from_field(f_soa, layout_policy="layout_left", ref=True) + + sfg.code(f"using field_soa = {strip_ptr_ref(f_soa_mdspan.dtype)};") + sfg.function("checkLayoutSoa")( + *check_layout(f_soa, f_soa_mdspan) + ) + + f_aos = ps.fields("f_aos(9): double[17, 19, 32]", layout="aos") + f_aos_mdspan = std.mdspan.from_field(f_aos, ref=True) + sfg.code(f"using field_aos = {strip_ptr_ref(f_aos_mdspan.dtype)};") + + sfg.function("checkLayoutAos")( + *check_layout(f_aos, f_aos_mdspan) + ) + + f_c = ps.fields("f_c(9): double[17, 19, 32]", layout="c") + f_c_mdspan = std.mdspan.from_field(f_c, layout_policy="layout_right", ref=True) + sfg.code(f"using field_c = {strip_ptr_ref(f_c_mdspan.dtype)};") + + sfg.function("checkLayoutC")( + *check_layout(f_c, f_c_mdspan) + )