Coverage for src/pystencilssfg/lang/cpp/std_mdspan.py: 90%
67 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from typing import cast
2from sympy import Symbol
4from pystencils import Field, DynamicType
5from pystencils.types import (
6 PsType,
7 PsUnsignedIntegerType,
8 UserTypeSpec,
9 create_type,
10)
12from pystencilssfg.lang.expressions import AugExpr
14from ...lang import SupportsFieldExtraction, cpptype, HeaderFile, ExprLike
class StdMdspan(AugExpr, SupportsFieldExtraction):
    """Represents an `std::mdspan` instance.

    The `std::mdspan <https://en.cppreference.com/w/cpp/container/mdspan>`_
    provides non-owning views into contiguous or strided n-dimensional arrays.
    It has been added to the C++ STL with the C++23 standard.
    As such, it is a natural data structure to target with pystencils kernels.

    **Concerning Headers and Namespaces**

    Since ``std::mdspan`` is not yet widely adopted
    (at the time of writing, libc++ ships it as of LLVM 18, but GCC libstdc++ does not include it yet),
    you might have to manually include an implementation in your project
    (you can get a reference implementation at https://github.com/kokkos/mdspan).
    However, when working with a non-standard mdspan implementation,
    the path to its header and the namespace it is defined in will likely be different.

    To tell pystencils-sfg which headers to include and which namespace to use for ``mdspan``,
    use `StdMdspan.configure`;
    for instance, adding this call before creating any ``mdspan`` objects will
    set their namespace to `std::experimental`, and require ``<experimental/mdspan>`` to be included:

    >>> from pystencilssfg.lang.cpp import std
    >>> std.mdspan.configure("std::experimental", "<experimental/mdspan>")

    **Creation from pystencils fields**

    Using `from_field`, ``mdspan`` objects can be created directly from `Field <pystencils.Field>` instances.
    The `extents`_ of the ``mdspan`` type will be inferred from the field;
    each fixed entry in the field's shape will become a fixed entry of the ``mdspan``'s extents.

    The ``mdspan``'s `layout_policy`_ defaults to `std::layout_stride`_,
    which might not be the optimal choice depending on the memory layout of your fields.
    You may therefore override this by specifying the name of the desired layout policy.
    To map pystencils field layout identifiers to layout policies, consult the following table:

    +------------------------+--------------------------+
    | pystencils Layout Name | ``mdspan`` Layout Policy |
    +========================+==========================+
    | ``"fzyx"``             | `std::layout_left`_      |
    | ``"soa"``              |                          |
    | ``"f"``                |                          |
    | ``"reverse_numpy"``    |                          |
    +------------------------+--------------------------+
    | ``"c"``                | `std::layout_right`_     |
    | ``"numpy"``            |                          |
    +------------------------+--------------------------+
    | ``"zyxf"``             | `std::layout_stride`_    |
    | ``"aos"``              |                          |
    +------------------------+--------------------------+

    The array-of-structures (``"aos"``, ``"zyxf"``) layout has no equivalent layout policy in the C++ standard,
    so it can only be mapped onto ``layout_stride``.

    .. _extents: https://en.cppreference.com/w/cpp/container/mdspan/extents
    .. _layout_policy: https://en.cppreference.com/w/cpp/named_req/LayoutMappingPolicy
    .. _std::layout_left: https://en.cppreference.com/w/cpp/container/mdspan/layout_left
    .. _std::layout_right: https://en.cppreference.com/w/cpp/container/mdspan/layout_right
    .. _std::layout_stride: https://en.cppreference.com/w/cpp/container/mdspan/layout_stride

    Args:
        T: Element type of the mdspan
        extents: One entry per dimension; each entry is rendered with ``str()``
            into the ``extents`` template argument list (use `dynamic_extent`
            for runtime-sized dimensions)
        index_type: Index type of the ``extents`` type
        layout_policy: Layout policy; ``None`` selects ``layout_stride``.
            The short names ``"layout_left"``, ``"layout_right"`` and ``"layout_stride"``
            are qualified with the configured namespace; any other string is used verbatim.
        ref: Forwarded to the type template as its ``ref`` flag
        const: Forwarded to the type template as its ``const`` flag
    """

    # NOTE(review): this is not updated by `configure`; a non-std mdspan
    # implementation may require its own namespace's ``dynamic_extent`` — confirm.
    dynamic_extent = "std::dynamic_extent"

    _namespace = "std"
    _template = cpptype("std::mdspan< {T}, {extents}, {layout_policy} >", "<mdspan>")

    @classmethod
    def configure(cls, namespace: str = "std", header: str | HeaderFile = "<mdspan>"):
        """Configure the namespace and header ``std::mdspan`` is defined in.

        Only ``mdspan`` objects created after this call are affected,
        since the type is assembled in the constructor.

        Args:
            namespace: C++ namespace providing ``mdspan``, ``extents`` and the layout policies
            header: Header file that must be included to use ``mdspan``
        """
        cls._namespace = namespace
        # Double braces escape the f-string so that {T}, {extents} and
        # {layout_policy} survive as placeholders to be filled in by `cpptype`.
        cls._template = cpptype(
            f"{namespace}::mdspan< {{T}}, {{extents}}, {{layout_policy}} >", header
        )

    def __init__(
        self,
        T: UserTypeSpec,
        extents: tuple[int | str, ...],
        index_type: UserTypeSpec = PsUnsignedIntegerType(64),
        layout_policy: str | None = None,
        ref: bool = False,
        const: bool = False,
    ):
        T = create_type(T)

        # Joining the index type together with the extents avoids a dangling
        # comma (``extents< T,  >``) for zero-dimensional mdspans.
        extents_args = ", ".join(
            [create_type(index_type).c_string(), *(str(e) for e in extents)]
        )
        extents_str = f"{self._namespace}::extents< {extents_args} >"

        if layout_policy is None:
            layout_policy = f"{self._namespace}::layout_stride"
        elif layout_policy in ("layout_left", "layout_right", "layout_stride"):
            layout_policy = f"{self._namespace}::{layout_policy}"

        dtype = self._template(
            T=T, extents=extents_str, layout_policy=layout_policy, const=const, ref=ref
        )
        super().__init__(dtype)

        self._element_type = T
        self._extents_type = extents_str
        self._layout_type = layout_policy
        self._dim = len(extents)

    @property
    def element_type(self) -> PsType:
        """Element type of the mdspan, as a pystencils type."""
        return self._element_type

    @property
    def extents_type(self) -> str:
        """C++ spelling of the mdspan's ``extents`` type."""
        return self._extents_type

    @property
    def layout_type(self) -> str:
        """C++ spelling of the mdspan's layout policy."""
        return self._layout_type

    def extent(self, r: int | ExprLike) -> AugExpr:
        """Expression ``<mdspan>.extent(r)``."""
        return AugExpr.format("{}.extent({})", self, r)

    def stride(self, r: int | ExprLike) -> AugExpr:
        """Expression ``<mdspan>.stride(r)``."""
        return AugExpr.format("{}.stride({})", self, r)

    def data_handle(self) -> AugExpr:
        """Expression ``<mdspan>.data_handle()``."""
        return AugExpr.format("{}.data_handle()", self)

    #   SupportsFieldExtraction protocol

    def _extract_ptr(self) -> AugExpr:
        return self.data_handle()

    def _extract_size(self, coordinate: int) -> AugExpr | None:
        # Valid ranks are 0 .. dim-1; ``coordinate == dim`` is already out of range.
        if coordinate >= self._dim:
            return None
        return self.extent(coordinate)

    def _extract_stride(self, coordinate: int) -> AugExpr | None:
        # Valid ranks are 0 .. dim-1; ``coordinate == dim`` is already out of range.
        if coordinate >= self._dim:
            return None
        return self.stride(coordinate)

    @staticmethod
    def from_field(
        field: Field,
        extents_type: UserTypeSpec = PsUnsignedIntegerType(64),
        layout_policy: str | None = None,
        ref: bool = False,
        const: bool = False,
    ):
        """Creates a `std::mdspan` instance for a given pystencils field.

        Spatial dimensions come first, followed by index dimensions.
        Symbolic shape entries become `dynamic_extent`; fixed entries are kept.

        Args:
            field: The field to map
            extents_type: Index type of the ``extents`` type
            layout_policy: Layout policy, as in the `StdMdspan` constructor
            ref: Forwarded to the `StdMdspan` constructor
            const: Forwarded to the `StdMdspan` constructor

        Returns:
            The ``mdspan`` object, bound via ``.var(field.name)``.

        Raises:
            ValueError: If the field is dynamically typed.
        """
        if isinstance(field.dtype, DynamicType):
            raise ValueError("Cannot map dynamically typed field to std::mdspan")

        extents: list[str | int] = [
            StdMdspan.dynamic_extent if isinstance(s, Symbol) else cast(int, s)
            for s in (*field.spatial_shape, *field.index_shape)
        ]

        return StdMdspan(
            field.dtype,
            tuple(extents),
            index_type=extents_type,
            layout_policy=layout_policy,
            ref=ref,
            const=const,
        ).var(field.name)
def mdspan_ref(field: Field, extents_type: PsType = PsUnsignedIntegerType(64)):
    """Deprecated: create a reference-typed ``mdspan`` for ``field``.

    Equivalent to ``StdMdspan.from_field(field, extents_type, ref=True)``,
    but additionally emits a `FutureWarning`.

    Args:
        field: The field to map
        extents_type: Index type of the ``extents`` type

    Returns:
        Whatever `StdMdspan.from_field` returns for these arguments.
    """
    from warnings import warn

    warn(
        "`mdspan_ref` is deprecated and will be removed in version 0.1. Use `std.mdspan.from_field` instead.",
        FutureWarning,
        # Attribute the warning to the caller's line, not to this wrapper.
        stacklevel=2,
    )
    return StdMdspan.from_field(field, extents_type, ref=True)