# std_span.py
from pystencils.field import Field
from pystencils.types import UserTypeSpec, create_type, PsType

from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype


class StdSpan(SrcField):
    """Source-field wrapper whose C++ type is built from the ``std::span< T >`` template.

    The type template is declared together with ``"<span>"`` (presumably the
    header that must be included for it -- confirm against ``cpptype``).
    """

    _template = cpptype("std::span< {T} >", "<span>")

    def __init__(self, T: UserTypeSpec, ref=False, const=False):
        """Create the span type for element type ``T``.

        Args:
            T: Element type specification; normalized through ``create_type``.
            ref: Forwarded to the type template (presumably makes the type a
                reference -- confirm against ``cpptype``).
            const: Forwarded to the type template (presumably const-qualifies
                the type -- confirm against ``cpptype``).
        """
        element_type = create_type(T)
        span_type = self._template(T=element_type, const=const, ref=ref)
        self._element_type = element_type
        super().__init__(span_type)

    @property
    def element_type(self) -> PsType:
        """The element type of this span, as produced by ``create_type``."""
        return self._element_type

    def get_extraction(self) -> IFieldExtraction:
        """Return an object yielding the pointer, size and stride expressions for this span.

        Only coordinate 0 carries size and stride information; for any
        coordinate greater than zero the extraction reports ``None``.
        """
        # The nested class has its own `self`, so capture the span under another name.
        span = self

        class Extraction(IFieldExtraction):
            def ptr(self) -> AugExpr:
                # Expression calling `.data()` on the span.
                return AugExpr.format("{}.data()", span)

            def size(self, coordinate: int) -> AugExpr | None:
                # No size is reported beyond the first coordinate.
                if coordinate > 0:
                    return None
                return AugExpr.format("{}.size()", span)

            def stride(self, coordinate: int) -> AugExpr | None:
                # No stride is reported beyond the first coordinate;
                # the first coordinate always has the literal stride "1".
                if coordinate > 0:
                    return None
                return AugExpr.format("1")

        return Extraction()

    @staticmethod
    def from_field(field: Field, ref: bool = False, const: bool = False):
        """Build a span variable named after ``field`` with the field's dtype.

        Args:
            field: Field to map; it must have at most one spatial dimension
                and an index shape of ``()`` or ``(1,)``.
            ref: Forwarded to the ``StdSpan`` constructor.
            const: Forwarded to the ``StdSpan`` constructor.

        Returns:
            The result of ``StdSpan(...).var(field.name)``.

        Raises:
            ValueError: If the field does not satisfy the constraints above.
        """
        mappable = field.spatial_dimensions <= 1 and field.index_shape in ((), (1,))
        if not mappable:
            raise ValueError(
                "Only one-dimensional fields with trivial index dimensions can be mapped onto `std::span`"
            )

        span_type = StdSpan(field.dtype, ref=ref, const=const)
        return span_type.var(field.name)


def std_span_ref(field: Field):
    """Deprecated shortcut for ``StdSpan.from_field(field, ref=True)``.

    Emits a ``FutureWarning`` on every call and then delegates to
    ``StdSpan.from_field``.

    Args:
        field: Field to map; passed on unchanged to ``StdSpan.from_field``.

    Returns:
        Whatever ``StdSpan.from_field(field, ref=True)`` returns.
    """
    from warnings import warn

    warn(
        "`std_span_ref` is deprecated and will be removed in version 0.1. Use `std.span.from_field` instead.",
        FutureWarning,
        # Attribute the warning to the caller's line (the code that needs
        # migrating) instead of to this wrapper.
        stacklevel=2,
    )
    return StdSpan.from_field(field, ref=True)