Skip to content
Snippets Groups Projects
kernelfunction.py 7.64 KiB
from __future__ import annotations

from warnings import warn
from abc import ABC
from typing import Callable, Sequence, Any

from .._deprecation import _deprecated

from .ast.structural import PsBlock

from .constraints import KernelParamsConstraint
from ..types import PsType
from .jit import JitBase, no_jit

from ..enums import Target
from ..field import Field
from ..sympyextensions import TypedSymbol
from ..sympyextensions.typed_sympy import (
    FieldShapeSymbol,
    FieldStrideSymbol,
    FieldPointerSymbol,
)


class KernelParameter:
    __match_args__ = ("name", "dtype")

    def __init__(self, name: str, dtype: PsType):
        self._name = name
        self._dtype = dtype

    @property
    def name(self):
        return self._name

    @property
    def dtype(self):
        return self._dtype

    def _hashable_contents(self):
        return (self._name, self._dtype)

    def __hash__(self) -> int:
        return hash(self._hashable_contents())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, KernelParameter):
            return False

        return (
            type(self) is type(other)
            and self._hashable_contents() == other._hashable_contents()
        )

    def __str__(self) -> str:
        return self._name

    def __repr__(self) -> str:
        return f"{type(self).__name__}(name = {self._name}, dtype = {self._dtype})"

    @property
    def symbol(self) -> TypedSymbol:
        return TypedSymbol(self.name, self.dtype)

    @property
    def is_field_parameter(self) -> bool:
        warn(
            "`is_field_parameter` is deprecated and will be removed in a future version of pystencils. "
            "Use `isinstance(param, FieldParameter)` instead.",
            DeprecationWarning,
        )
        return isinstance(self, FieldParameter)

    @property
    def is_field_pointer(self) -> bool:
        warn(
            "`is_field_pointer` is deprecated and will be removed in a future version of pystencils. "
            "Use `isinstance(param, FieldPointerParam)` instead.",
            DeprecationWarning,
        )
        return isinstance(self, FieldPointerParam)

    @property
    def is_field_stride(self) -> bool:
        warn(
            "`is_field_stride` is deprecated and will be removed in a future version of pystencils. "
            "Use `isinstance(param, FieldStrideParam)` instead.",
            DeprecationWarning,
        )
        return isinstance(self, FieldStrideParam)

    @property
    def is_field_shape(self) -> bool:
        warn(
            "`is_field_shape` is deprecated and will be removed in a future version of pystencils. "
            "Use `isinstance(param, FieldShapeParam)` instead.",
            DeprecationWarning,
        )
        return isinstance(self, FieldShapeParam)


class FieldParameter(KernelParameter, ABC):
    __match_args__ = KernelParameter.__match_args__ + ("field",)

    def __init__(self, name: str, dtype: PsType, field: Field):
        super().__init__(name, dtype)
        self._field = field

    @property
    def field(self):
        return self._field

    @property
    def fields(self):
        warn(
            "`fields` is deprecated and will be removed in a future version of pystencils. "
            "In pystencils >= 2.0, field parameters are only associated with a single field."
            "Use the `field` property instead.",
            DeprecationWarning,
        )
        return [self._field]

    @property
    def field_name(self) -> str:
        warn(
            "`field_name` is deprecated and will be removed in a future version of pystencils. "
            "Use `field.name` instead.",
            DeprecationWarning,
        )
        return self._field.name

    def _hashable_contents(self):
        return super()._hashable_contents() + (self._field,)


class FieldShapeParam(FieldParameter):
    __match_args__ = FieldParameter.__match_args__ + ("coordinate",)

    def __init__(self, name: str, dtype: PsType, field: Field, coordinate: int):
        super().__init__(name, dtype, field)
        self._coordinate = coordinate

    @property
    def coordinate(self):
        return self._coordinate

    @property
    def symbol(self) -> FieldShapeSymbol:
        return FieldShapeSymbol(self.field.name, self.coordinate, self.dtype)

    def _hashable_contents(self):
        return super()._hashable_contents() + (self._coordinate,)


class FieldStrideParam(FieldParameter):
    __match_args__ = FieldParameter.__match_args__ + ("coordinate",)

    def __init__(self, name: str, dtype: PsType, field: Field, coordinate: int):
        super().__init__(name, dtype, field)
        self._coordinate = coordinate

    @property
    def coordinate(self):
        return self._coordinate

    @property
    def symbol(self) -> FieldStrideSymbol:
        return FieldStrideSymbol(self.field.name, self.coordinate, self.dtype)

    def _hashable_contents(self):
        return super()._hashable_contents() + (self._coordinate,)


class FieldPointerParam(FieldParameter):
    def __init__(self, name: str, dtype: PsType, field: Field):
        super().__init__(name, dtype, field)

    @property
    def symbol(self) -> FieldPointerSymbol:
        return FieldPointerSymbol(self.field.name, self.field.dtype, const=True)


class KernelFunction:
    """A pystencils kernel function.

    The kernel function is the final result of the translation process.
    It is immutable, and its AST should not be altered any more, either, as this
    might invalidate information about the kernel already stored in the `KernelFunction` object.
    """

    def __init__(
        self,
        body: PsBlock,
        target: Target,
        name: str,
        parameters: Sequence[KernelParameter],
        required_headers: set[str],
        constraints: Sequence[KernelParamsConstraint],
        jit: JitBase = no_jit,
    ):
        self._body: PsBlock = body
        self._target = target
        self._name = name
        self._params = tuple(parameters)
        self._required_headers = required_headers
        self._constraints = tuple(constraints)
        self._jit = jit
        self._metadata: dict[str, Any] = dict()

    @property
    def metadata(self) -> dict[str, Any]:
        return self._metadata

    @property
    def body(self) -> PsBlock:
        return self._body

    @property
    def target(self) -> Target:
        return self._target

    @property
    def name(self) -> str:
        return self._name

    @name.setter
    def name(self, n: str):
        self._name = n

    @property
    def function_name(self) -> str:
        _deprecated("function_name", "name")
        return self._name

    @function_name.setter
    def function_name(self, n: str):
        _deprecated("function_name", "name")
        self._name = n

    @property
    def parameters(self) -> tuple[KernelParameter, ...]:
        return self._params

    def get_parameters(self) -> tuple[KernelParameter, ...]:
        _deprecated("KernelFunction.get_parameters", "KernelFunction.parameters")
        return self.parameters

    def get_fields(self) -> set[Field]:
        return set(p.field for p in self._params if isinstance(p, FieldParameter))

    @property
    def fields_accessed(self) -> set[Field]:
        warn(
            "`fields_accessed` is deprecated and will be removed in a future version of pystencils. "
            "Use `get_fields` instead.",
            DeprecationWarning,
        )
        return self.get_fields()

    @property
    def required_headers(self) -> set[str]:
        return self._required_headers

    @property
    def constraints(self) -> tuple[KernelParamsConstraint, ...]:
        return self._constraints

    def compile(self) -> Callable[..., None]:
        return self._jit.compile(self)