Skip to content
Snippets Groups Projects
sycl.py 9.36 KiB
from __future__ import annotations
from typing import Sequence
from enum import Enum
import re

from pystencils.types import UserTypeSpec, PsType, PsCustomType, create_type
from pystencils import Target

from pystencilssfg.composer.basic_composer import SequencerArg

from ..config import CodeStyle
from ..exceptions import SfgException
from ..context import SfgContext
from ..composer import (
    SfgBasicComposer,
    SfgClassComposer,
    SfgComposer,
    SfgComposerMixIn,
    make_sequence,
)
from ..ir import (
    SfgKernelHandle,
    SfgCallTreeNode,
    SfgCallTreeLeaf,
    SfgKernelCallNode,
)

from ..lang import SfgVar, AugExpr, cpptype, Ref, VarLike, _VarLike, asvar
from ..lang.cpp.sycl_accessor import SyclAccessor


accessor = SyclAccessor


class SyclComposerMixIn(SfgComposerMixIn):
    """Composer mix-in for SYCL code generation"""

    def sycl_handler(self, name: str) -> SyclHandler:
        """Obtain a `SyclHandler`, which represents a ``sycl::handler`` object."""
        return SyclHandler(self._ctx).var(name)

    def sycl_group(self, dims: int, name: str) -> SyclGroup:
        """Obtain a `SyclHandler`, which represents a ``sycl::handler`` object."""
        return SyclGroup(dims, self._ctx).var(name)

    def sycl_range(self, dims: int, name: str, ref: bool = False) -> SyclRange:
        return SyclRange(dims, ref=ref).var(name)


class SyclComposer(SfgBasicComposer, SfgClassComposer, SyclComposerMixIn):
    """Composer extension providing SYCL code generation capabilities"""

    def __init__(self, sfg: SfgContext | SfgComposer):
        super().__init__(sfg)


class SyclRange(AugExpr):
    _template = cpptype("sycl::range< {dims} >", "<sycl/sycl.hpp>")

    def __init__(self, dims: int, const: bool = False, ref: bool = False):
        dtype = self._template(dims=dims, const=const, ref=ref)
        super().__init__(dtype)


class SyclHandler(AugExpr):
    """Represents a SYCL command group handler (``sycl::handler``)."""

    _type = cpptype("sycl::handler", "<sycl/sycl.hpp>")

    def __init__(self, ctx: SfgContext):
        dtype = Ref(self._type())
        super().__init__(dtype)

        self._ctx = ctx

    def parallel_for(
        self,
        range: VarLike | Sequence[int],
    ):
        """Generate a ``parallel_for`` kernel invocation using this command group handler.
        The syntax of this uses a chain of two calls to mimic C++ syntax:

        .. code-block:: Python

            sfg.parallel_for(range)(
                # Body
            )

        The body is constructed via sequencing (see `make_sequence`).

        Args:
            range: Object, or tuple of integers, indicating the kernel's iteration range
        """
        if isinstance(range, _VarLike):
            range = asvar(range)

        def check_kernel(khandle: SfgKernelHandle):
            kfunc = khandle.kernel
            if kfunc.target != Target.SYCL:
                raise SfgException(
                    f"Kernel given to `parallel_for` is no SYCL kernel: {khandle.fqname}"
                )

        id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>")

        def filter_id(param: SfgVar) -> bool:
            return (
                isinstance(param.dtype, PsCustomType)
                and id_regex.search(param.dtype.c_string()) is not None
            )

        def sequencer(*args: SequencerArg):
            id_param = []
            for arg in args:
                if isinstance(arg, SfgKernelCallNode):
                    check_kernel(arg._kernel_handle)
                    id_param.append(
                        list(filter(filter_id, arg._kernel_handle.scalar_parameters))[0]
                    )

            if not all(item == id_param[0] for item in id_param):
                raise ValueError(
                    "id_param should be the same for all kernels in parallel_for"
                )
            tree = make_sequence(*args)

            kernel_lambda = SfgLambda(("=",), (id_param[0],), tree, None)
            return SyclKernelInvoke(
                self, SyclInvokeType.ParallelFor, range, kernel_lambda
            )

        return sequencer


class SyclGroup(AugExpr):
    """Represents a SYCL group (``sycl::group``)."""

    _template = cpptype("sycl::group< {dims} >", "<sycl/sycl.hpp>")

    def __init__(self, dimensions: int, ctx: SfgContext):
        dtype = Ref(self._template(dims=dimensions))
        super().__init__(dtype)

        self._dimensions = dimensions
        self._ctx = ctx

    def parallel_for_work_item(
        self, range: VarLike | Sequence[int], khandle: SfgKernelHandle
    ):
        """Generate a ``parallel_for_work_item` kernel invocation on this group.`

        Args:
            range: Object, or tuple of integers, indicating the kernel's iteration range
            kernel: Handle to the pystencils-kernel to be executed
        """
        if isinstance(range, _VarLike):
            range = asvar(range)

        kfunc = khandle.kernel
        if kfunc.target != Target.SYCL:
            raise SfgException(
                f"Kernel given to `parallel_for` is no SYCL kernel: {khandle.fqname}"
            )

        id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>")

        def filter_id(param: SfgVar) -> bool:
            return (
                isinstance(param.dtype, PsCustomType)
                and id_regex.search(param.dtype.c_string()) is not None
            )

        id_param = list(filter(filter_id, khandle.scalar_parameters))[0]
        h_item = SfgVar("item", PsCustomType("sycl::h_item< 3 >"))

        comp = SfgComposer(self._ctx)
        tree = comp.seq(
            comp.set_param(id_param, AugExpr.format("{}.get_local_id()", h_item)),
            SfgKernelCallNode(khandle),
        )

        kernel_lambda = SfgLambda(("=",), (h_item,), tree, None)
        invoke = SyclKernelInvoke(
            self, SyclInvokeType.ParallelForWorkItem, range, kernel_lambda
        )
        return invoke


class SfgLambda:
    """Models a C++ lambda expression"""

    def __init__(
        self,
        captures: Sequence[str],
        params: Sequence[SfgVar],
        tree: SfgCallTreeNode,
        return_type: UserTypeSpec | None = None,
    ) -> None:
        self._captures = tuple(captures)
        self._params = tuple(params)
        self._tree = tree
        self._return_type: PsType | None = (
            create_type(return_type) if return_type is not None else None
        )

        from ..ir.postprocessing import CallTreePostProcessing

        postprocess = CallTreePostProcessing()
        self._required_params = postprocess(self._tree).function_params - set(
            self._params
        )

    @property
    def captures(self) -> tuple[str, ...]:
        return self._captures

    @property
    def parameters(self) -> tuple[SfgVar, ...]:
        return self._params

    @property
    def body(self) -> SfgCallTreeNode:
        return self._tree

    @property
    def return_type(self) -> PsType | None:
        return self._return_type

    @property
    def required_parameters(self) -> set[SfgVar]:
        return self._required_params

    def get_code(self, cstyle: CodeStyle):
        captures = ", ".join(self._captures)
        params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params)
        body = self._tree.get_code(cstyle)
        body = cstyle.indent(body)
        rtype = (
            f"-> {self._return_type.c_string()} "
            if self._return_type is not None
            else ""
        )

        return f"[{captures}] ({params}) {rtype}{{\n{body}\n}}"


class SyclInvokeType(Enum):
    ParallelFor = ("parallel_for", SyclHandler)
    ParallelForWorkItem = ("parallel_for_work_item", SyclGroup)

    @property
    def method(self) -> str:
        return self.value[0]

    @property
    def invoker_class(self) -> type:
        return self.value[1]


class SyclKernelInvoke(SfgCallTreeLeaf):
    """A SYCL kernel invocation on a given handler or group"""

    def __init__(
        self,
        invoker: SyclHandler | SyclGroup,
        invoke_type: SyclInvokeType,
        range: SfgVar | Sequence[int],
        lamb: SfgLambda,
    ):
        if not isinstance(invoker, invoke_type.invoker_class):
            raise SfgException(
                f"Cannot invoke kernel via `{invoke_type.method}` on a {type(invoker)}"
            )

        super().__init__()
        self._invoker = invoker
        self._invoke_type = invoke_type
        self._range: SfgVar | tuple[int, ...] = (
            range if isinstance(range, SfgVar) else tuple(range)
        )
        self._lambda = lamb

        self._required_params = set(invoker.depends | lamb.required_parameters)

        if isinstance(range, SfgVar):
            self._required_params.add(range)

    @property
    def invoker(self) -> SyclHandler | SyclGroup:
        return self._invoker

    @property
    def range(self) -> SfgVar | tuple[int, ...]:
        return self._range

    @property
    def kernel(self) -> SfgLambda:
        return self._lambda

    @property
    def depends(self) -> set[SfgVar]:
        return self._required_params

    def get_code(self, cstyle: CodeStyle) -> str:
        if isinstance(self._range, SfgVar):
            range_code = self._range.name
        else:
            range_code = "{ " + ", ".join(str(r) for r in self._range) + " }"

        kernel_code = self._lambda.get_code(cstyle)
        invoker = str(self._invoker)
        method = self._invoke_type.method

        return f"{invoker}.{method}({range_code}, {kernel_code});"