Coverage for src/pystencilssfg/lang/cpp/std_mdspan.py: 90%

67 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from typing import cast 

2from sympy import Symbol 

3 

4from pystencils import Field, DynamicType 

5from pystencils.types import ( 

6 PsType, 

7 PsUnsignedIntegerType, 

8 UserTypeSpec, 

9 create_type, 

10) 

11 

12from pystencilssfg.lang.expressions import AugExpr 

13 

14from ...lang import SupportsFieldExtraction, cpptype, HeaderFile, ExprLike 

15 

16 

class StdMdspan(AugExpr, SupportsFieldExtraction):
    """Represents an `std::mdspan` instance.

    The `std::mdspan <https://en.cppreference.com/w/cpp/container/mdspan>`_
    provides non-owning views into contiguous or strided n-dimensional arrays.
    It has been added to the C++ STL with the C++23 standard.
    As such, it is a natural data structure to target with pystencils kernels.

    **Concerning Headers and Namespaces**

    Since ``std::mdspan`` is not yet widely adopted
    (libc++ ships it as of LLVM 18, but GCC libstdc++ does not include it yet),
    you might have to manually include an implementation in your project
    (you can get a reference implementation at https://github.com/kokkos/mdspan).
    However, when working with a non-standard mdspan implementation,
    the path to its header and the namespace it is defined in will likely be different.

    To tell pystencils-sfg which headers to include and which namespace to use for ``mdspan``,
    use `StdMdspan.configure`;
    for instance, adding this call before creating any ``mdspan`` objects will
    set their namespace to `std::experimental`, and require ``<experimental/mdspan>`` to be included:

    >>> from pystencilssfg.lang.cpp import std
    >>> std.mdspan.configure("std::experimental", "<experimental/mdspan>")

    **Creation from pystencils fields**

    Using `from_field`, ``mdspan`` objects can be created directly from `Field <pystencils.Field>` instances.
    The `extents`_ of the ``mdspan`` type will be inferred from the field;
    each fixed entry in the field's shape will become a fixed entry of the ``mdspan``'s extents.

    The ``mdspan``'s `layout_policy`_ defaults to `std::layout_stride`_,
    which might not be the optimal choice depending on the memory layout of your fields.
    You may therefore override this by specifying the name of the desired layout policy.
    To map pystencils field layout identifiers to layout policies, consult the following table:

    +------------------------+--------------------------+
    | pystencils Layout Name | ``mdspan`` Layout Policy |
    +========================+==========================+
    | ``"fzyx"``             | `std::layout_left`_      |
    | ``"soa"``              |                          |
    | ``"f"``                |                          |
    | ``"reverse_numpy"``    |                          |
    +------------------------+--------------------------+
    | ``"c"``                | `std::layout_right`_     |
    | ``"numpy"``            |                          |
    +------------------------+--------------------------+
    | ``"zyxf"``             | `std::layout_stride`_    |
    | ``"aos"``              |                          |
    +------------------------+--------------------------+

    The array-of-structures (``"aos"``, ``"zyxf"``) layout has no equivalent layout policy in the C++ standard,
    so it can only be mapped onto ``layout_stride``.

    .. _extents: https://en.cppreference.com/w/cpp/container/mdspan/extents
    .. _layout_policy: https://en.cppreference.com/w/cpp/named_req/LayoutMappingPolicy
    .. _std::layout_left: https://en.cppreference.com/w/cpp/container/mdspan/layout_left
    .. _std::layout_right: https://en.cppreference.com/w/cpp/container/mdspan/layout_right
    .. _std::layout_stride: https://en.cppreference.com/w/cpp/container/mdspan/layout_stride

    Args:
        T: Element type of the mdspan
        extents: One entry per dimension; either a fixed size or
            `StdMdspan.dynamic_extent` for sizes known only at runtime
        index_type: Index type, emitted as the first template argument of ``extents``
        layout_policy: Layout policy. ``"layout_left"``, ``"layout_right"`` and
            ``"layout_stride"`` are qualified with the configured namespace;
            any other string is used verbatim. Defaults to ``layout_stride``.
        ref: Forwarded to the type template as ``ref`` (reference qualification)
        const: Forwarded to the type template as ``const`` (const qualification)
    """

    dynamic_extent = "std::dynamic_extent"

    # Class-level configuration; replaced by `configure` and read by `__init__`,
    # so it only affects mdspan objects created after `configure` was called.
    _namespace = "std"
    _template = cpptype("std::mdspan< {T}, {extents}, {layout_policy} >", "<mdspan>")

    @classmethod
    def configure(cls, namespace: str = "std", header: str | HeaderFile = "<mdspan>"):
        """Configure the namespace and header ``std::mdspan`` is defined in."""
        cls._namespace = namespace
        # The doubled braces keep `{T}`, `{extents}` and `{layout_policy}` as literal
        # placeholders for `cpptype`; single braces would be evaluated by the f-string.
        cls._template = cpptype(
            f"{namespace}::mdspan< {{T}}, {{extents}}, {{layout_policy}} >", header
        )

    def __init__(
        self,
        T: UserTypeSpec,
        extents: tuple[int | str, ...],
        index_type: UserTypeSpec = PsUnsignedIntegerType(64),
        layout_policy: str | None = None,
        ref: bool = False,
        const: bool = False,
    ):
        element_type = create_type(T)

        # Build the template argument list so that a zero-dimensional mdspan
        # yields `extents< IndexType >` rather than a dangling comma.
        extents_args = [create_type(index_type).c_string(), *(str(e) for e in extents)]
        extents_str = f"{self._namespace}::extents< {', '.join(extents_args)} >"

        if layout_policy is None:
            layout_policy = f"{self._namespace}::layout_stride"
        elif layout_policy in ("layout_left", "layout_right", "layout_stride"):
            # Bare standard policy names are qualified with the configured namespace.
            layout_policy = f"{self._namespace}::{layout_policy}"

        dtype = self._template(
            T=element_type,
            extents=extents_str,
            layout_policy=layout_policy,
            const=const,
            ref=ref,
        )
        super().__init__(dtype)

        self._element_type = element_type
        self._extents_type = extents_str
        self._layout_type = layout_policy
        self._dim = len(extents)

    @property
    def element_type(self) -> PsType:
        """The mdspan's element type."""
        return self._element_type

    @property
    def extents_type(self) -> str:
        """The C++ spelling of the mdspan's ``extents`` type."""
        return self._extents_type

    @property
    def layout_type(self) -> str:
        """The C++ spelling of the mdspan's layout policy."""
        return self._layout_type

    def extent(self, r: int | ExprLike) -> AugExpr:
        """Expression ``<self>.extent(r)``."""
        return AugExpr.format("{}.extent({})", self, r)

    def stride(self, r: int | ExprLike) -> AugExpr:
        """Expression ``<self>.stride(r)``."""
        return AugExpr.format("{}.stride({})", self, r)

    def data_handle(self) -> AugExpr:
        """Expression ``<self>.data_handle()``."""
        return AugExpr.format("{}.data_handle()", self)

    # SupportsFieldExtraction protocol

    def _extract_ptr(self) -> AugExpr:
        return self.data_handle()

    def _extract_size(self, coordinate: int) -> AugExpr | None:
        # Coordinates are 0-based: valid ranks are 0 .. dim-1.
        if coordinate >= self._dim:
            return None
        return self.extent(coordinate)

    def _extract_stride(self, coordinate: int) -> AugExpr | None:
        # Coordinates are 0-based: valid ranks are 0 .. dim-1.
        if coordinate >= self._dim:
            return None
        return self.stride(coordinate)

    @staticmethod
    def from_field(
        field: Field,
        extents_type: UserTypeSpec = PsUnsignedIntegerType(64),
        layout_policy: str | None = None,
        ref: bool = False,
        const: bool = False,
    ):
        """Creates a `std::mdspan` instance for a given pystencils field.

        The spatial shape followed by the index shape of ``field`` become the
        mdspan's extents; symbolic entries map to `StdMdspan.dynamic_extent`.
        The result is a variable named after the field.

        Raises:
            ValueError: If the field's dtype is a `DynamicType`.
        """
        if isinstance(field.dtype, DynamicType):
            raise ValueError("Cannot map dynamically typed field to std::mdspan")

        extents: list[str | int] = [
            StdMdspan.dynamic_extent if isinstance(s, Symbol) else cast(int, s)
            for s in (*field.spatial_shape, *field.index_shape)
        ]

        return StdMdspan(
            field.dtype,
            tuple(extents),
            index_type=extents_type,
            layout_policy=layout_policy,
            ref=ref,
            const=const,
        ).var(field.name)

191 

192 

def mdspan_ref(field: Field, extents_type: PsType = PsUnsignedIntegerType(64)):
    """Deprecated: create an mdspan variable for ``field`` with ``ref=True``.

    Equivalent to ``StdMdspan.from_field(field, extents_type, ref=True)``;
    emits a `FutureWarning` pointing at the caller.
    Use `std.mdspan.from_field` instead.
    """
    from warnings import warn

    warn(
        "`mdspan_ref` is deprecated and will be removed in version 0.1. Use `std.mdspan.from_field` instead.",
        FutureWarning,
        # Attribute the warning to the code calling `mdspan_ref`, not to this module.
        stacklevel=2,
    )
    return StdMdspan.from_field(field, extents_type, ref=True)