Coverage for src/pystencilssfg/lang/cpp/std_vector.py: 85%
39 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from pystencils import Field, DynamicType
2from pystencils.types import UserTypeSpec, create_type, PsType
4from ...lang import SupportsFieldExtraction, SupportsVectorExtraction, AugExpr, cpptype
class StdVector(AugExpr, SupportsFieldExtraction, SupportsVectorExtraction):
    """Augmented expression typed as a C++ ``std::vector< T >``.

    The data type is produced from the ``std::vector< {T} >`` type template
    (registered together with ``<vector>``, presumably the header to include).
    The class implements the field-extraction hooks (pointer, size, stride)
    and the vector-extraction hook (component access).
    """

    _template = cpptype("std::vector< {T} >", "<vector>")

    def __init__(
        self,
        T: UserTypeSpec,
        unsafe: bool = False,
        ref: bool = False,
        const: bool = False,
    ):
        """Set up the vector expression.

        Args:
            T: Element type specification; normalised through ``create_type``.
            unsafe: If true, components are accessed with ``operator[]``
                instead of ``.at()`` (see `_extract_component`).
            ref: Forwarded to the type template.
            const: Forwarded to the type template.
        """
        element_type = create_type(T)
        vector_type = self._template(T=element_type, const=const, ref=ref)
        super().__init__(vector_type)

        self._element_type = element_type
        self._unsafe = unsafe

    @property
    def element_type(self) -> PsType:
        """Element type of the vector, as created from ``T``."""
        return self._element_type

    def _extract_ptr(self) -> AugExpr:
        """Return the expression ``<self>.data()``."""
        return AugExpr.format("{}.data()", self)

    def _extract_size(self, coordinate: int) -> AugExpr | None:
        """Return ``<self>.size()`` for coordinate 0, ``None`` for any coordinate above 0."""
        if coordinate > 0:
            return None
        return AugExpr.format("{}.size()", self)

    def _extract_stride(self, coordinate: int) -> AugExpr | None:
        """Return the literal stride ``1`` for coordinate 0, ``None`` for any coordinate above 0."""
        if coordinate > 0:
            return None
        return AugExpr.format("1")

    def _extract_component(self, coordinate: int) -> AugExpr:
        """Return the element-access expression for ``coordinate``.

        Uses ``<self>[i]`` when the vector was created with ``unsafe=True``
        and ``<self>.at(i)`` otherwise.
        """
        access_pattern = "{}[{}]" if self._unsafe else "{}.at({})"
        return AugExpr.format(access_pattern, self, coordinate)

    @staticmethod
    def from_field(field: Field, ref: bool = True, const: bool = False):
        """Create a `StdVector` variable named after and typed like ``field``.

        Args:
            field: Field to mirror; must have at most one spatial dimension
                and an index shape of ``()`` or ``(1,)``.
            ref: Forwarded to the `StdVector` constructor.
            const: Forwarded to the `StdVector` constructor.

        Returns:
            The result of ``.var(field.name)`` on a safe-access `StdVector`
            with the field's data type as element type.

        Raises:
            ValueError: If the field has more than one spatial dimension or a
                different index shape, or if its data type is a `DynamicType`.
        """
        fits_in_vector = field.spatial_dimensions <= 1 and field.index_shape in (
            (),
            (1,),
        )
        if not fits_in_vector:
            raise ValueError(
                f"Cannot create std::vector from more-than-one-dimensional field {field}."
            )

        if isinstance(field.dtype, DynamicType):
            raise ValueError("Cannot map dynamically typed field to std::vector")

        vector = StdVector(field.dtype, unsafe=False, ref=ref, const=const)
        return vector.var(field.name)
def std_vector_ref(field: Field):
    """Deprecated shim around ``StdVector.from_field(field, ref=True)``.

    Emits a `FutureWarning` on every call and then delegates to
    `StdVector.from_field`.

    Args:
        field: Field passed through to `StdVector.from_field`.

    Returns:
        Whatever ``StdVector.from_field(field, ref=True)`` returns.
    """
    from warnings import warn

    warn(
        "`std_vector_ref` is deprecated and will be removed in version 0.1. Use `std.vector.from_field` instead.",
        FutureWarning,
        # Attribute the warning to the caller's line, which is the code that
        # needs to be migrated, rather than to this shim.
        stacklevel=2,
    )
    return StdVector.from_field(field, ref=True)