Coverage for src/pystencilssfg/lang/cpp/sycl_accessor.py: 92%
37 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from pystencils import Field, DynamicType
2from pystencils.types import UserTypeSpec, create_type
4from ...lang import AugExpr, cpptype, SupportsFieldExtraction
class SyclAccessor(AugExpr, SupportsFieldExtraction):
    """Represent a
    `SYCL Accessor <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#subsec:accessors>`_.

    .. note::

        SYCL accessors do not expose information about strides, so the linearization is done under
        the assumption that the underlying memory is contiguous, as described
        `here <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_multi_dimensional_objects_and_linearization>`_

    Args:
        T: Element type of the accessor; normalized via ``create_type``.
        dimensions: Number of accessor dimensions; must be 1, 2 or 3.
        ref: Forwarded to the C++ type template as ``ref``
            (``from_field`` uses ``ref=True`` to create a ``sycl::accessor &``).
        const: Forwarded to the C++ type template as ``const``.

    Raises:
        ValueError: If ``dimensions`` is not 1, 2 or 3.
    """  # noqa: E501

    _template = cpptype("sycl::accessor< {T}, {dims} >", "<sycl/sycl.hpp>")

    def __init__(
        self,
        T: UserTypeSpec,
        dimensions: int,
        ref: bool = False,
        const: bool = False,
    ):
        T = create_type(T)
        # Enforce both bounds; the error message has always promised exactly 1, 2 or 3.
        if not 1 <= dimensions <= 3:
            raise ValueError("sycl accessors can only have dims 1, 2 or 3")
        dtype = self._template(T=T, dims=dimensions, const=const, ref=ref)

        super().__init__(dtype)

        self._dim = dimensions
        # Memory is assumed contiguous (see class note), so the last,
        # fastest-varying axis has unit stride.
        self._inner_stride = 1

    def _extract_ptr(self) -> AugExpr:
        """Return an expression yielding the raw data pointer of the accessor."""
        return AugExpr.format(
            "{}.get_multi_ptr<sycl::access::decorated::no>().get()",
            self,
        )

    def _extract_size(self, coordinate: int) -> AugExpr | None:
        """Return an expression for the extent of axis ``coordinate``.

        ``coordinate`` is a zero-based axis index; ``None`` is returned when it
        does not name an axis of this accessor.
        """
        # Valid axes are 0 .. dim-1; `coordinate == dim` would otherwise emit an
        # out-of-range `get_range().get(dim)`.
        if not 0 <= coordinate < self._dim:
            return None
        return AugExpr.format("{}.get_range().get({})", self, coordinate)

    def _extract_stride(self, coordinate: int) -> AugExpr | None:
        """Return an expression for the element stride of axis ``coordinate``.

        The layout is treated as contiguous with the last axis varying fastest:
        the stride of axis ``d`` is the product of the ranges of all axes after
        ``d``, times the inner stride. ``None`` is returned when ``coordinate``
        (zero-based) does not name an axis of this accessor.
        """
        # Reject out-of-range axes up front; previously `coordinate == dim`
        # reached the product branch with no factors and produced " * 1".
        if not 0 <= coordinate < self._dim:
            return None
        if coordinate == self._dim - 1:
            return AugExpr.format("{}", self._inner_stride)

        trailing_axes = range(coordinate + 1, self._dim)
        expr = " * ".join("{}.get_range().get({})" for _ in trailing_axes)
        expr += " * {}"
        args = []
        for d in trailing_axes:
            args.extend([self, d])
        return AugExpr.format(expr, *args, self._inner_stride)

    @staticmethod
    def from_field(field: Field, ref: bool = True):
        """Creates a `sycl::accessor &` for a given pystencils field.

        The accessor's element type is the field's dtype and its dimensionality
        is the field's spatial plus index dimensions. The result is bound to a
        variable named ``field.name`` via ``.var``.

        Args:
            field: The pystencils field to map.
            ref: Forwarded to the constructor; ``True`` (default) yields a
                reference type.

        Raises:
            ValueError: If the field is dynamically typed, or if its total
                dimensionality is not 1, 2 or 3.
        """

        if isinstance(field.dtype, DynamicType):
            raise ValueError("Cannot map dynamically typed field to sycl::accessor")

        return SyclAccessor(
            field.dtype,
            field.spatial_dimensions + field.index_dimensions,
            ref=ref,
        ).var(field.name)