Skip to content
Snippets Groups Projects
std_mdspan.py 7.66 KiB
from typing import cast
from sympy import Symbol

from pystencils import Field, DynamicType
from pystencils.types import (
    PsType,
    PsUnsignedIntegerType,
    UserTypeSpec,
    create_type,
)

from pystencilssfg.lang.expressions import AugExpr

from ...lang import SrcField, IFieldExtraction, cpptype, HeaderFile, ExprLike


class StdMdspan(SrcField):
    """Represents an `std::mdspan` instance.

    The `std::mdspan <https://en.cppreference.com/w/cpp/container/mdspan>`_
    provides non-owning views into contiguous or strided n-dimensional arrays.
    It has been added to the C++ STL with the C++23 standard.
    As such, it is a natural data structure to target with pystencils kernels.

    **Concerning Headers and Namespaces**

    Since ``std::mdspan`` is not yet widely adopted
    (libc++ ships it as of LLVM 18, but GCC libstdc++ does not include it yet),
    you might have to manually include an implementation in your project
    (you can get a reference implementation at https://github.com/kokkos/mdspan).
    However, when working with a non-standard mdspan implementation,
    the path to its header and the namespace it is defined in will likely be different.

    To tell pystencils-sfg which headers to include and which namespace to use for ``mdspan``,
    use `StdMdspan.configure`;
    for instance, adding this call before creating any ``mdspan`` objects will
    set their namespace to ``std::experimental``, and require ``<experimental/mdspan>`` to be included:

    >>> from pystencilssfg.lang.cpp import std
    >>> std.mdspan.configure("std::experimental", "<experimental/mdspan>")

    **Creation from pystencils fields**

    Using `from_field`, ``mdspan`` objects can be created directly from `Field <pystencils.Field>` instances.
    The `extents`_ of the ``mdspan`` type will be inferred from the field;
    each fixed entry in the field's shape will become a fixed entry of the ``mdspan``'s extents.

    The ``mdspan``'s `layout_policy`_ defaults to `std::layout_stride`_,
    which might not be the optimal choice depending on the memory layout of your fields.
    You may therefore override this by specifying the name of the desired layout policy.
    To map pystencils field layout identifiers to layout policies, consult the following table:

    +------------------------+--------------------------+
    | pystencils Layout Name | ``mdspan`` Layout Policy |
    +========================+==========================+
    | ``"fzyx"``             | `std::layout_left`_      |
    | ``"soa"``              |                          |
    | ``"f"``                |                          |
    | ``"reverse_numpy"``    |                          |
    +------------------------+--------------------------+
    | ``"c"``                | `std::layout_right`_     |
    | ``"numpy"``            |                          |
    +------------------------+--------------------------+
    | ``"zyxf"``             | `std::layout_stride`_    |
    | ``"aos"``              |                          |
    +------------------------+--------------------------+

    The array-of-structures (``"aos"``, ``"zyxf"``) layout has no equivalent layout policy in the C++ standard,
    so it can only be mapped onto ``layout_stride``.

    .. _extents: https://en.cppreference.com/w/cpp/container/mdspan/extents
    .. _layout_policy: https://en.cppreference.com/w/cpp/named_req/LayoutMappingPolicy
    .. _std::layout_left: https://en.cppreference.com/w/cpp/container/mdspan/layout_left
    .. _std::layout_right: https://en.cppreference.com/w/cpp/container/mdspan/layout_right
    .. _std::layout_stride: https://en.cppreference.com/w/cpp/container/mdspan/layout_stride

    Args:
        T: Element type of the mdspan
        extents: Static extents of the mdspan, one entry per dimension.
            Use `StdMdspan.dynamic_extent` for entries that are only known at runtime.
        index_type: Index type of the mdspan's ``extents`` type
        layout_policy: Layout policy of the mdspan. Defaults to ``layout_stride``
            in the configured namespace. The unqualified names ``"layout_left"``,
            ``"layout_right"`` and ``"layout_stride"`` are qualified with the configured
            namespace; any other string is used verbatim as the policy type.
        ref: Forwarded as ``ref`` to the mdspan type factory (reference qualification)
        const: Forwarded as ``const`` to the mdspan type factory (const qualification)
    """

    # Sentinel used for extents that are only known at runtime.
    # Kept in sync with the configured namespace by `configure`.
    dynamic_extent = "std::dynamic_extent"

    _namespace = "std"
    _template = cpptype("std::mdspan< {T}, {extents}, {layout_policy} >", "<mdspan>")

    @classmethod
    def configure(cls, namespace: str = "std", header: str | HeaderFile = "<mdspan>"):
        """Configure the namespace and header ``std::mdspan`` is defined in.

        This affects all ``mdspan`` objects created afterwards: the ``mdspan`` type itself,
        its ``extents`` type, the standard layout policies, and the `dynamic_extent` sentinel
        are all taken from ``namespace``.

        Args:
            namespace: C++ namespace containing ``mdspan``, e.g. ``"std::experimental"``
            header: Header that must be included to use ``mdspan``
        """
        cls._namespace = namespace
        # The dynamic-extent sentinel must come from the same implementation
        # as ``extents``; otherwise a non-std mdspan would reference ``std::dynamic_extent``.
        cls.dynamic_extent = f"{namespace}::dynamic_extent"
        cls._template = cpptype(
            f"{namespace}::mdspan< {{T}}, {{extents}}, {{layout_policy}} >", header
        )

    def __init__(
        self,
        T: UserTypeSpec,
        extents: tuple[int | str, ...],
        index_type: UserTypeSpec = PsUnsignedIntegerType(64),
        layout_policy: str | None = None,
        ref: bool = False,
        const: bool = False,
    ):
        T = create_type(T)

        extents_type_str = create_type(index_type).c_string()
        extents_str = f"{self._namespace}::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"

        if layout_policy is None:
            layout_policy = f"{self._namespace}::layout_stride"
        elif layout_policy in ("layout_left", "layout_right", "layout_stride"):
            # Bare standard policy names are qualified with the configured namespace.
            layout_policy = f"{self._namespace}::{layout_policy}"

        dtype = self._template(
            T=T, extents=extents_str, layout_policy=layout_policy, const=const, ref=ref
        )
        super().__init__(dtype)

        self._element_type = T
        self._extents_type = extents_str
        self._layout_type = layout_policy
        self._dim = len(extents)

    @property
    def element_type(self) -> PsType:
        """Element type of this mdspan."""
        return self._element_type

    @property
    def extents_type(self) -> str:
        """C++ spelling of this mdspan's ``extents`` type."""
        return self._extents_type

    @property
    def layout_type(self) -> str:
        """C++ spelling of this mdspan's layout policy."""
        return self._layout_type

    def extent(self, r: int | ExprLike) -> AugExpr:
        """Expression ``<mdspan>.extent(r)``."""
        return AugExpr.format("{}.extent({})", self, r)

    def stride(self, r: int | ExprLike) -> AugExpr:
        """Expression ``<mdspan>.stride(r)``."""
        return AugExpr.format("{}.stride({})", self, r)

    def data_handle(self) -> AugExpr:
        """Expression ``<mdspan>.data_handle()``."""
        return AugExpr.format("{}.data_handle()", self)

    def get_extraction(self) -> IFieldExtraction:
        """Return an extraction reading the data pointer, extents and strides from this mdspan.

        ``size`` and ``stride`` return ``None`` for coordinates not below the mdspan's rank.
        """
        mdspan = self

        class Extraction(IFieldExtraction):
            def ptr(self) -> AugExpr:
                return mdspan.data_handle()

            def size(self, coordinate: int) -> AugExpr | None:
                # Valid ranks are 0 .. _dim - 1; ``extent(_dim)`` would be out of range.
                if coordinate >= mdspan._dim:
                    return None
                return mdspan.extent(coordinate)

            def stride(self, coordinate: int) -> AugExpr | None:
                # Same bound as ``size``: ``stride(_dim)`` would be out of range.
                if coordinate >= mdspan._dim:
                    return None
                return mdspan.stride(coordinate)

        return Extraction()

    @staticmethod
    def from_field(
        field: Field,
        extents_type: UserTypeSpec = PsUnsignedIntegerType(64),
        layout_policy: str | None = None,
        ref: bool = False,
        const: bool = False,
    ):
        """Creates a ``std::mdspan`` instance for a given pystencils field.

        Args:
            field: The field to map; its spatial shape followed by its index shape
                determines the mdspan's extents.
            extents_type: Index type of the mdspan's ``extents`` type
            layout_policy: Layout policy name; see `StdMdspan` for the mapping table
            ref: Forwarded to the `StdMdspan` constructor
            const: Forwarded to the `StdMdspan` constructor

        Returns:
            The mdspan bound to the variable name ``field.name`` (via ``.var``).

        Raises:
            ValueError: If the field's element type is a `DynamicType`.
        """
        if isinstance(field.dtype, DynamicType):
            raise ValueError("Cannot map dynamically typed field to std::mdspan")

        # Symbolic shape entries are only known at runtime and become dynamic extents;
        # fixed entries are baked into the mdspan's extents type.
        extents: list[str | int] = [
            StdMdspan.dynamic_extent if isinstance(s, Symbol) else cast(int, s)
            for s in (*field.spatial_shape, *field.index_shape)
        ]

        return StdMdspan(
            field.dtype,
            tuple(extents),
            index_type=extents_type,
            layout_policy=layout_policy,
            ref=ref,
            const=const,
        ).var(field.name)


def mdspan_ref(field: Field, extents_type: PsType = PsUnsignedIntegerType(64)):
    """Deprecated: create a reference-qualified mdspan for ``field``.

    Equivalent to ``StdMdspan.from_field(field, extents_type, ref=True)``;
    use ``std.mdspan.from_field`` instead.

    Warns:
        FutureWarning: Always, since this function is deprecated.
    """
    from warnings import warn

    warn(
        "`mdspan_ref` is deprecated and will be removed in version 0.1. Use `std.mdspan.from_field` instead.",
        FutureWarning,
        # Attribute the warning to the caller's line, not to this wrapper.
        stacklevel=2,
    )
    return StdMdspan.from_field(field, extents_type, ref=True)