from ...lang import SrcField, IFieldExtraction from pystencils import Field, DynamicType from pystencils.types import UserTypeSpec, create_type from ...lang import AugExpr, cpptype class SyclAccessor(SrcField): """Represent a `SYCL Accessor <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#subsec:accessors>`_. .. note:: Sycl Accessor do not expose information about strides, so the linearization is done under the assumption that the underlying memory is contiguous, as descibed `here <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_multi_dimensional_objects_and_linearization>`_ """ # noqa: E501 _template = cpptype("sycl::accessor< {T}, {dims} >", "<sycl/sycl.hpp>") def __init__( self, T: UserTypeSpec, dimensions: int, ref: bool = False, const: bool = False, ): T = create_type(T) if dimensions > 3: raise ValueError("sycl accessors can only have dims 1, 2 or 3") dtype = self._template(T=T, dims=dimensions, const=const, ref=ref) super().__init__(dtype) self._dim = dimensions self._inner_stride = 1 def get_extraction(self) -> IFieldExtraction: accessor = self class Extraction(IFieldExtraction): def ptr(self) -> AugExpr: return AugExpr.format( "{}.get_multi_ptr<sycl::access::decorated::no>().get()", accessor, ) def size(self, coordinate: int) -> AugExpr | None: if coordinate > accessor._dim: return None else: return AugExpr.format( "{}.get_range().get({})", accessor, coordinate ) def stride(self, coordinate: int) -> AugExpr | None: if coordinate > accessor._dim: return None elif coordinate == accessor._dim - 1: return AugExpr.format("{}", accessor._inner_stride) else: exprs = [] args = [] for d in range(coordinate + 1, accessor._dim): args.extend([accessor, d]) exprs.append("{}.get_range().get({})") expr = " * ".join(exprs) expr += " * {}" return AugExpr.format(expr, *args, accessor._inner_stride) return Extraction() @staticmethod def from_field(field: Field, ref: bool = True): """Creates a `sycl::accessor &` for a given pystencils field.""" if isinstance(field.dtype, DynamicType): raise ValueError("Cannot map dynamically typed field to sycl::accessor") return SyclAccessor( field.dtype, field.spatial_dimensions + field.index_dimensions, ref=ref, ).var(field.name)