Skip to content
Snippets Groups Projects
sycl_accessor.py 3.27 KiB
import math
from ...lang import SrcField, IFieldExtraction
from ...ir.source_components import SfgHeaderInclude
from typing import Sequence

from pystencils import Field
from pystencils.types import (
    PsType,
    PsCustomType,
)

from pystencilssfg.lang.expressions import AugExpr


class SyclAccessor(SrcField):
    """Source-side ``sycl::accessor`` that can back a pystencils field.

    The generated C++ type is ``sycl::accessor< T, dimensions >``, followed by
    ``&`` when ``reference`` is true. Only the spatial dimensions are queried
    from the accessor's range. The field's index dimensions (``index_shape``)
    have fixed extents and are treated as the innermost dimensions. Strides are
    computed from the accessor's range under the assumption of a contiguous,
    row-major layout.

    Args:
        T: Element type of the accessor.
        dimensions: Number of spatial dimensions; must be 1, 2 or 3.
        index_shape: Extents of the field's index dimensions (may be empty).
        reference: Whether the generated type is an lvalue reference.

    Raises:
        ValueError: If ``dimensions`` is not 1, 2 or 3.
    """

    def __init__(
        self,
        T: PsType,
        dimensions: int,
        index_shape: Sequence[int],
        reference: bool = False,
    ):
        cpp_typestr = T.c_string()
        if dimensions not in [1, 2, 3]:
            raise ValueError("sycl accessors can only have dims 1, 2 or 3")
        typestring = (
            f"sycl::accessor< {cpp_typestr}, {dimensions} > {'&' if reference else ''}"
        )
        super().__init__(PsCustomType(typestring))

        # Snapshot the index shape so later mutation of the caller's sequence
        # cannot change the sizes/strides reported by the extraction.
        index_shape = tuple(index_shape)

        self._spatial_dimensions = dimensions
        self._index_dimensions = len(index_shape)
        self._index_shape = index_shape
        self._index_size = math.prod(index_shape)
        self._total_dimensions_ = self._spatial_dimensions + self._index_dimensions

    @property
    def required_includes(self) -> set[SfgHeaderInclude]:
        """Headers the generated code needs in order to use the accessor."""
        return {SfgHeaderInclude("sycl/sycl.hpp", system_header=True)}

    def get_extraction(self) -> IFieldExtraction:
        """Return an extraction producing pointer, size and stride expressions.

        Coordinates ``0 .. dimensions - 1`` are spatial. The following
        ``len(index_shape)`` coordinates are index dimensions. Any other
        coordinate yields ``None`` from ``size`` and ``stride``.
        """
        accessor = self

        class Extraction(IFieldExtraction):
            def ptr(self) -> AugExpr:
                """Data pointer obtained via ``get_multi_ptr<decorated::no>().get()``."""
                return AugExpr.format(
                    "{}.get_multi_ptr<sycl::access::decorated::no>().get()",
                    accessor,
                )

            def size(self, coordinate: int) -> AugExpr | None:
                """Extent of the field in ``coordinate``, or ``None`` if out of range."""
                spatial = accessor._spatial_dimensions
                if not 0 <= coordinate < accessor._total_dimensions_:
                    return None
                if coordinate < spatial:
                    return AugExpr.format(
                        "{}.get_range().get({})", accessor, coordinate
                    )
                # Index dimensions are not part of the accessor's range (it only
                # has `dimensions` entries); their extents are fixed and known
                # at generation time.
                return AugExpr.format(
                    "{}", accessor._index_shape[coordinate - spatial]
                )

            def stride(self, coordinate: int) -> AugExpr | None:
                """Linear stride (in elements) of ``coordinate``, or ``None`` if out of range."""
                spatial = accessor._spatial_dimensions
                if not 0 <= coordinate < accessor._total_dimensions_:
                    return None
                if coordinate >= spatial - 1:
                    # Innermost spatial dimension or an index dimension: the
                    # stride is the product of the index extents further inside.
                    start = coordinate - spatial + 1
                    return AugExpr.format(
                        "{}", math.prod(accessor._index_shape[start:])
                    )
                # Outer spatial dimension: product of all inner spatial extents
                # (read from the accessor's range) times the total index size.
                factors = []
                args = []
                for d in range(coordinate + 1, spatial):
                    factors.append("{}.get_range().get({})")
                    args.extend([accessor, d])
                expr = " * ".join(factors) + " * {}"
                return AugExpr.format(expr, *args, accessor._index_size)

        return Extraction()


def sycl_accessor_ref(field: Field):
    """Create a reference-typed ``sycl::accessor`` variable for a pystencils field.

    The accessor's element type is the field's dtype, and its dimensionality is
    the field's number of spatial dimensions. The field's index shape is passed
    along as well. The resulting variable is named after the field.

    Args:
        field: The pystencils field to expose through a SYCL accessor.

    Returns:
        The variable produced by ``SyclAccessor(...).var(field.name)``.
    """
    # SYCL accessors allow at most 3 dimensions. If the index dimensions were
    # also mapped onto the accessor, only 2D LBM-style fields would be possible.
    # In principle one could map to something like
    # sycl::buffer<std::array<double, 19>, 3>, but then we would need to
    # generate kernels that take SYCL accessors as arguments.
    accessor_type = SyclAccessor(
        T=field.dtype,
        dimensions=field.spatial_dimensions,
        index_shape=field.index_shape,
        reference=True,
    )
    return accessor_type.var(field.name)