Coverage for src/pystencilssfg/lang/cpp/std_vector.py: 85%

39 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from pystencils import Field, DynamicType 

2from pystencils.types import UserTypeSpec, create_type, PsType 

3 

4from ...lang import SupportsFieldExtraction, SupportsVectorExtraction, AugExpr, cpptype 

5 

6 

class StdVector(AugExpr, SupportsFieldExtraction, SupportsVectorExtraction):
    """Expression wrapper for a C++ ``std::vector< T >`` object.

    Provides field extraction (data pointer, size and stride along a single
    axis) and vector extraction (element access by index).
    """

    # C++ type template, instantiated per element type in ``__init__``.
    # The second argument ``"<vector>"`` is presumably the header to include.
    _template = cpptype("std::vector< {T} >", "<vector>")

    def __init__(
        self,
        T: UserTypeSpec,
        unsafe: bool = False,
        ref: bool = False,
        const: bool = False,
    ):
        """Create a ``std::vector`` expression with element type ``T``.

        Args:
            T: Element type specification, normalized through ``create_type``.
            unsafe: If true, element access uses ``operator[]`` instead of
                ``.at()``.
            ref: Forwarded to the type template as ``ref``.
            const: Forwarded to the type template as ``const``.
        """
        elem = create_type(T)
        super().__init__(self._template(T=elem, const=const, ref=ref))

        self._element_type = elem
        self._unsafe = unsafe

    @property
    def element_type(self) -> PsType:
        """The normalized element type this vector was created with."""
        return self._element_type

    def _extract_ptr(self) -> AugExpr:
        # Pointer to the vector's storage, obtained through ``.data()``.
        return AugExpr.format("{}.data()", self)

    def _extract_size(self, coordinate: int) -> AugExpr | None:
        # Only a single axis exists; coordinates above 0 have no size.
        if coordinate > 0:
            return None
        return AugExpr.format("{}.size()", self)

    def _extract_stride(self, coordinate: int) -> AugExpr | None:
        # Elements are addressed with unit stride along the single axis.
        if coordinate > 0:
            return None
        return AugExpr.format("1")

    def _extract_component(self, coordinate: int) -> AugExpr:
        # ``unsafe`` selects unchecked ``operator[]``; otherwise use ``.at()``.
        accessor = "{}[{}]" if self._unsafe else "{}.at({})"
        return AugExpr.format(accessor, self, coordinate)

    @staticmethod
    def from_field(field: Field, ref: bool = True, const: bool = False):
        """Create a ``std::vector`` variable matching a one-dimensional field.

        Args:
            field: Source field. It must have at most one spatial dimension
                and an index shape of ``()`` or ``(1,)``.
            ref: Forwarded to the ``StdVector`` constructor.
            const: Forwarded to the ``StdVector`` constructor.

        Returns:
            The result of ``.var(field.name)`` on the constructed vector
            expression (bounds-checked access, i.e. ``unsafe=False``).

        Raises:
            ValueError: If the field is not one-dimensional in the sense
                above, or if its dtype is a ``DynamicType``.
        """
        is_one_dimensional = (
            field.spatial_dimensions <= 1 and field.index_shape in ((), (1,))
        )
        if not is_one_dimensional:
            raise ValueError(
                f"Cannot create std::vector from more-than-one-dimensional field {field}."
            )

        if isinstance(field.dtype, DynamicType):
            raise ValueError("Cannot map dynamically typed field to std::vector")

        vec_expr = StdVector(field.dtype, unsafe=False, ref=ref, const=const)
        return vec_expr.var(field.name)

62 

63 

def std_vector_ref(field: Field):
    """Deprecated: create a ``std::vector`` reference variable from a field.

    Equivalent to ``StdVector.from_field(field, ref=True)``. Emits a
    ``FutureWarning`` on every call.

    Args:
        field: One-dimensional field to map to a ``std::vector``.

    Returns:
        The value returned by ``StdVector.from_field(field, ref=True)``.

    Raises:
        ValueError: Propagated from ``StdVector.from_field`` when the field
            cannot be mapped.
    """
    from warnings import warn

    warn(
        "`std_vector_ref` is deprecated and will be removed in version 0.1. Use `std.vector.from_field` instead.",
        FutureWarning,
        # Attribute the warning to the caller's line, not this wrapper's.
        stacklevel=2,
    )
    return StdVector.from_field(field, ref=True)