Coverage for src/pystencilssfg/lang/cpp/sycl_accessor.py: 92%

37 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from pystencils import Field, DynamicType 

2from pystencils.types import UserTypeSpec, create_type 

3 

4from ...lang import AugExpr, cpptype, SupportsFieldExtraction 

5 

6 

class SyclAccessor(AugExpr, SupportsFieldExtraction):
    """Represent a
    `SYCL Accessor <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#subsec:accessors>`_.

    .. note::

        SYCL accessors do not expose information about strides, so the linearization is done under
        the assumption that the underlying memory is contiguous, as described
        `here <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_multi_dimensional_objects_and_linearization>`_
    """  # noqa: E501

    _template = cpptype("sycl::accessor< {T}, {dims} >", "<sycl/sycl.hpp>")

    def __init__(
        self,
        T: UserTypeSpec,
        dimensions: int,
        ref: bool = False,
        const: bool = False,
    ):
        """Create an expression of type ``sycl::accessor< T, dimensions >``.

        Args:
            T: Element type of the accessor; normalized through ``create_type``.
            dimensions: Rank of the accessor; must be 1, 2 or 3.
            ref: Forwarded to the type template; ``from_field`` documents
                ``ref=True`` as producing a ``sycl::accessor &``.
            const: Forwarded to the type template as its ``const`` qualifier flag.

        Raises:
            ValueError: If ``dimensions`` is not 1, 2 or 3.
        """
        T = create_type(T)
        # Check both bounds: the size/stride extraction below assumes at least one
        # axis, and the error message promises exactly the ranks 1, 2 and 3.
        if not 1 <= dimensions <= 3:
            raise ValueError("sycl accessors can only have dims 1, 2 or 3")
        dtype = self._template(T=T, dims=dimensions, const=const, ref=ref)

        super().__init__(dtype)

        self._dim = dimensions
        # Accessors carry no stride information; the innermost axis is assumed
        # to be contiguous (see the class note).
        self._inner_stride = 1

    def _extract_ptr(self) -> AugExpr:
        """Return an expression yielding the raw data pointer of this accessor.

        The pointer is obtained through
        ``get_multi_ptr<sycl::access::decorated::no>().get()``.
        """
        return AugExpr.format(
            "{}.get_multi_ptr<sycl::access::decorated::no>().get()",
            self,
        )

    def _extract_size(self, coordinate: int) -> AugExpr | None:
        """Return an expression for the extent of the accessor along ``coordinate``.

        Args:
            coordinate: 0-based axis index.

        Returns:
            ``<accessor>.get_range().get(coordinate)``, or ``None`` if ``coordinate``
            is not in ``[0, dimensions)``.
        """
        # Axes are 0-based, so ``coordinate == self._dim`` is already out of range.
        if not 0 <= coordinate < self._dim:
            return None
        return AugExpr.format("{}.get_range().get({})", self, coordinate)

    def _extract_stride(self, coordinate: int) -> AugExpr | None:
        """Return an expression for the stride (in elements) along ``coordinate``.

        The memory is assumed to be contiguous with the last axis varying fastest,
        so the stride of an axis is the product of the extents of all higher-index
        axes times the innermost stride.

        Args:
            coordinate: 0-based axis index.

        Returns:
            The stride expression, or ``None`` if ``coordinate`` is not in
            ``[0, dimensions)``.
        """
        if not 0 <= coordinate < self._dim:
            return None
        if coordinate == self._dim - 1:
            return AugExpr.format("{}", self._inner_stride)

        # One ``get_range().get(d)`` factor per trailing axis, then the inner stride.
        factors: list[str] = []
        args: list[object] = []
        for d in range(coordinate + 1, self._dim):
            factors.append("{}.get_range().get({})")
            args.extend([self, d])
        factors.append("{}")
        args.append(self._inner_stride)
        return AugExpr.format(" * ".join(factors), *args)

    @staticmethod
    def from_field(field: Field, ref: bool = True):
        """Create a `sycl::accessor` variable for a given pystencils field.

        By default (``ref=True``) the accessor type is a reference,
        ``sycl::accessor &``. The accessor rank is the sum of the field's spatial
        and index dimensions, and the resulting variable is named after the field.

        Raises:
            ValueError: If the field is dynamically typed, or (from the
                constructor) if its total rank is not 1, 2 or 3.
        """

        if isinstance(field.dtype, DynamicType):
            raise ValueError("Cannot map dynamically typed field to sycl::accessor")

        return SyclAccessor(
            field.dtype,
            field.spatial_dimensions + field.index_dimensions,
            ref=ref,
        ).var(field.name)