from typing import cast
from sympy import Symbol

from pystencils import Field, DynamicType
from pystencils.types import (
    PsType,
    PsUnsignedIntegerType,
    UserTypeSpec,
    create_type,
)

from pystencilssfg.lang.expressions import AugExpr

from ...lang import SrcField, IFieldExtraction, cpptype, HeaderFile, ExprLike


class StdMdspan(SrcField):
    """Represents an `std::mdspan` instance.

    The `std::mdspan <https://en.cppreference.com/w/cpp/container/mdspan>`_
    provides non-owning views into contiguous or strided n-dimensional arrays.
    It has been added to the C++ STL with the C++23 standard.
    As such, it is a natural data structure to target with pystencils kernels.

    **Concerning Headers and Namespaces**

    Since ``std::mdspan`` is not yet widely adopted
    (libc++ ships it as of LLVM 18, but GCC libstdc++ does not include it yet),
    you might have to manually include an implementation in your project
    (you can get a reference implementation at https://github.com/kokkos/mdspan).
    However, when working with a non-standard mdspan implementation,
    the path to its header and the namespace it is defined in will likely be different.

    To tell pystencils-sfg which headers to include and which namespace to use for ``mdspan``,
    use `StdMdspan.configure`;
    for instance, adding this call before creating any ``mdspan`` objects will
    set their namespace to `std::experimental`, and require ``<experimental/mdspan>`` to be imported:

    >>> from pystencilssfg.lang.cpp import std
    >>> std.mdspan.configure("std::experimental", "<experimental/mdspan>")

    **Creation from pystencils fields**

    Using `from_field`, ``mdspan`` objects can be created directly from `Field <pystencils.Field>` instances.
    The `extents`_ of the ``mdspan`` type will be inferred from the field;
    each fixed entry in the field's shape will become a fixed entry of the ``mdspan``'s extents.

    The ``mdspan``'s `layout_policy`_ defaults to `std::layout_stride`_,
    which might not be the optimal choice depending on the memory layout of your fields.
    You may therefore override this by specifying the name of the desired layout policy.
    To map pystencils field layout identifiers to layout policies, consult the following table:

    +------------------------+--------------------------+
    | pystencils Layout Name | ``mdspan`` Layout Policy |
    +========================+==========================+
    | ``"fzyx"``             | `std::layout_left`_      |
    | ``"soa"``              |                          |
    | ``"f"``                |                          |
    | ``"reverse_numpy"``    |                          |
    +------------------------+--------------------------+
    | ``"c"``                | `std::layout_right`_     |
    | ``"numpy"``            |                          |
    +------------------------+--------------------------+
    | ``"zyxf"``             | `std::layout_stride`_    |
    | ``"aos"``              |                          |
    +------------------------+--------------------------+

    The array-of-structures (``"aos"``, ``"zyxf"``) layout has no equivalent layout policy in the C++ standard,
    so it can only be mapped onto ``layout_stride``.

    .. _extents: https://en.cppreference.com/w/cpp/container/mdspan/extents
    .. _layout_policy: https://en.cppreference.com/w/cpp/named_req/LayoutMappingPolicy
    .. _std::layout_left: https://en.cppreference.com/w/cpp/container/mdspan/layout_left
    .. _std::layout_right: https://en.cppreference.com/w/cpp/container/mdspan/layout_right
    .. _std::layout_stride: https://en.cppreference.com/w/cpp/container/mdspan/layout_stride

    Args:
        T: Element type of the mdspan
        extents: One entry per dimension; each entry is rendered verbatim into the
            ``extents< ... >`` template argument list (e.g. an integer or `dynamic_extent`)
        index_type: Integer type used as the first template argument of ``extents``
        layout_policy: Layout policy type name. ``None`` selects ``layout_stride``;
            the short names ``layout_left``, ``layout_right`` and ``layout_stride`` are
            prefixed with the configured namespace; any other string is used as given
        ref: Forwarded to the C++ type template
        const: Forwarded to the C++ type template
    """

    dynamic_extent = "std::dynamic_extent"

    _namespace = "std"
    _template = cpptype("std::mdspan< {T}, {extents}, {layout_policy} >", "<mdspan>")

    @classmethod
    def configure(cls, namespace: str = "std", header: str | HeaderFile = "<mdspan>"):
        """Configure the namespace and header ``std::mdspan`` is defined in."""
        cls._namespace = namespace
        cls._template = cpptype(
            f"{namespace}::mdspan< {{T}}, {{extents}}, {{layout_policy}} >", header
        )

    def __init__(
        self,
        T: UserTypeSpec,
        extents: tuple[int | str, ...],
        index_type: UserTypeSpec = PsUnsignedIntegerType(64),
        layout_policy: str | None = None,
        ref: bool = False,
        const: bool = False,
    ):
        T = create_type(T)

        extents_type_str = create_type(index_type).c_string()
        extents_str = f"{self._namespace}::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"

        if layout_policy is None:
            layout_policy = f"{self._namespace}::layout_stride"
        elif layout_policy in ("layout_left", "layout_right", "layout_stride"):
            # Unqualified standard policy names get the configured namespace;
            # anything else is taken to be an already fully-specified type name.
            layout_policy = f"{self._namespace}::{layout_policy}"

        dtype = self._template(
            T=T, extents=extents_str, layout_policy=layout_policy, const=const, ref=ref
        )
        super().__init__(dtype)

        self._element_type = T
        self._extents_type = extents_str
        self._layout_type = layout_policy
        self._dim = len(extents)

    @property
    def element_type(self) -> PsType:
        """Element type of this mdspan."""
        return self._element_type

    @property
    def extents_type(self) -> str:
        """The rendered ``extents< ... >`` type string of this mdspan."""
        return self._extents_type

    @property
    def layout_type(self) -> str:
        """The layout policy type name of this mdspan."""
        return self._layout_type

    def extent(self, r: int | ExprLike) -> AugExpr:
        """Build the expression ``<mdspan>.extent(r)``."""
        return AugExpr.format("{}.extent({})", self, r)

    def stride(self, r: int | ExprLike) -> AugExpr:
        """Build the expression ``<mdspan>.stride(r)``."""
        return AugExpr.format("{}.stride({})", self, r)

    def data_handle(self) -> AugExpr:
        """Build the expression ``<mdspan>.data_handle()``."""
        return AugExpr.format("{}.data_handle()", self)

    def get_extraction(self) -> IFieldExtraction:
        """Return an `IFieldExtraction` that reads pointer, sizes and strides from this mdspan."""
        mdspan = self

        class Extraction(IFieldExtraction):
            def ptr(self) -> AugExpr:
                return mdspan.data_handle()

            def size(self, coordinate: int) -> AugExpr | None:
                # Valid coordinates are 0 .. _dim - 1; `coordinate == _dim` is
                # already out of range and must not produce an `.extent()` call.
                if coordinate >= mdspan._dim:
                    return None
                return mdspan.extent(coordinate)

            def stride(self, coordinate: int) -> AugExpr | None:
                # Same bound as `size`: the mdspan has exactly `_dim` strides.
                if coordinate >= mdspan._dim:
                    return None
                return mdspan.stride(coordinate)

        return Extraction()

    @staticmethod
    def from_field(
        field: Field,
        extents_type: UserTypeSpec = PsUnsignedIntegerType(64),
        layout_policy: str | None = None,
        ref: bool = False,
        const: bool = False,
    ):
        """Creates a `std::mdspan` instance for a given pystencils field.

        Symbolic entries of the field's spatial and index shape become
        `dynamic_extent`; all other entries are used as fixed extents.

        Args:
            field: The pystencils field to map
            extents_type: Index type passed to the mdspan's ``extents``
            layout_policy: Layout policy name, see the class documentation
            ref: Forwarded to the `StdMdspan` constructor
            const: Forwarded to the `StdMdspan` constructor

        Returns:
            A variable of the mdspan type, named like the field.

        Raises:
            ValueError: If the field's data type is a `DynamicType`.
        """
        if isinstance(field.dtype, DynamicType):
            raise ValueError("Cannot map dynamically typed field to std::mdspan")

        # Spatial dimensions come first, followed by the index dimensions.
        extents: tuple[str | int, ...] = tuple(
            StdMdspan.dynamic_extent if isinstance(s, Symbol) else cast(int, s)
            for s in (*field.spatial_shape, *field.index_shape)
        )

        return StdMdspan(
            field.dtype,
            extents,
            index_type=extents_type,
            layout_policy=layout_policy,
            ref=ref,
            const=const,
        ).var(field.name)


def mdspan_ref(field: Field, extents_type: PsType = PsUnsignedIntegerType(64)):
    """Deprecated alias for ``StdMdspan.from_field(field, extents_type, ref=True)``.

    Emits a `FutureWarning` on every call.
    """
    from warnings import warn

    warn(
        "`mdspan_ref` is deprecated and will be removed in version 0.1. Use `std.mdspan.from_field` instead.",
        FutureWarning,
        # Attribute the warning to the caller rather than to this wrapper.
        stacklevel=2,
    )
    return StdMdspan.from_field(field, extents_type, ref=True)