diff --git a/docs/source/backend/objects.rst b/docs/source/backend/objects.rst index a28ac3b8534b32d0ef9df946c8980cf97eeaea8c..a39f4c24b1e7b0ceae4f24bdc5b8869b69c080ac 100644 --- a/docs/source/backend/objects.rst +++ b/docs/source/backend/objects.rst @@ -62,14 +62,12 @@ For example, this snippet defines a property type that models pointer alignment .. code-block:: python @dataclass(frozen=True) - class AlignmentProperty(PsSymbolProperty) + class AlignmentProperty(UniqueSymbolProperty) """Require this pointer symbol to be aligned at a particular byte boundary.""" byte_boundary: int - _unique: ClassVar[bool] = True - -The ``_unique`` flag in the above example ensures that only one property of this type can at any time -be attached to a symbol. +Inheriting from `UniqueSymbolProperty` ensures that at most one property of this type can be attached to +a symbol at any time. Properties can be added, queried, and removed using the `PsSymbol` properties API listed below. Many symbol properties are more relevant to consumers of generated kernels than to the code generator itself. @@ -87,15 +85,12 @@ Constants and Literals API Documentation ================= -The `memory <pystencils.backend.memory>` Module ------------------------------------------------ +.. automodule:: pystencils.backend.properties + :members: .. automodule:: pystencils.backend.memory :members: -The `constants <pystencils.backend.constants>` Module ------------------------------------------------------ - .. automodule:: pystencils.backend.constants :members: diff --git a/src/pystencils/backend/__init__.py b/src/pystencils/backend/__init__.py index a0b1c8f747984e3fffde5a336f40e2aa46ad631d..b947a112ecb2be7762fefdf54afd4dffc185c319 100644 --- a/src/pystencils/backend/__init__.py +++ b/src/pystencils/backend/__init__.py @@ -1,9 +1,5 @@ from .kernelfunction import ( KernelParameter, - FieldParameter, - FieldShapeParam, - FieldStrideParam, - FieldPointerParam, KernelFunction, GpuKernelFunction, ) @@ -12,10 +8,6 @@ from .constraints import KernelParamsConstraint __all__ = [ "KernelParameter", - "FieldParameter", - "FieldShapeParam", - "FieldStrideParam", - "FieldPointerParam", "KernelFunction", "GpuKernelFunction", "KernelParamsConstraint", diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 558cc8a0e667dd8896a9fc63a82693a7f3e8fd08..839b8fd9829a83b46dbe2419013959b42943b96c 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -9,7 +9,8 @@ from ...defaults import DEFAULTS from ...field import Field, FieldType from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType -from ..memory import PsSymbol, PsBuffer, FieldShape, FieldStride +from ..memory import PsSymbol, PsBuffer +from ..properties import FieldShape, FieldStride from ..constants import PsConstant from ...types import ( PsType, diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index 568b6c7994add39d9e952cad1364db51f04ceefa..a5bdab623380136fc77946c111c83bcace66c1a6 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -1,7 +1,6 @@ from __future__ import annotations from warnings import warn -from abc import ABC from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING from .._deprecation import _deprecated @@ -9,6 +8,13 @@ from .._deprecation import _deprecated from .ast.structural import PsBlock from .ast.analysis import collect_required_headers, collect_undefined_symbols from .memory import PsSymbol +from .properties import ( + PsSymbolProperty, + _FieldProperty, + FieldShape, + FieldStride, + FieldBasePtr, +) from .kernelcreation.context import KernelCreationContext from .platforms import Platform, GpuThreadsRange @@ -24,11 +30,18 @@ if TYPE_CHECKING: class KernelParameter: - __match_args__ = ("name", "dtype") + """Parameter to a `KernelFunction`.""" - def __init__(self, name: str, dtype: PsType): + __match_args__ = ("name", "dtype", "properties") + + def __init__( + self, name: str, dtype: PsType, properties: Iterable[PsSymbolProperty] = () + ): self._name = name self._dtype = dtype + self._properties: frozenset[PsSymbolProperty] = ( + frozenset(properties) if properties is not None else frozenset() + ) @property def name(self): @@ -39,8 +52,9 @@ class KernelParameter: return self._dtype def _hashable_contents(self): - return (self._name, self._dtype) + return (self._name, self._dtype, self._properties) + # TODO: Need? def __hash__(self) -> int: return hash(self._hashable_contents()) @@ -63,110 +77,56 @@ class KernelParameter: def symbol(self) -> TypedSymbol: return TypedSymbol(self.name, self.dtype) + @property + def fields(self) -> set[Field]: + """Set of fields associated with this parameter.""" + return set(p.field for p in filter(lambda p: isinstance(p, _FieldProperty), self.properties)) # type: ignore + + def get_properties( + self, prop_type: type[PsSymbolProperty] + ) -> set[PsSymbolProperty]: + """Retrieve all properties of the given type attached to this parameter""" + return set(filter(lambda p: isinstance(p, prop_type), self._properties)) + + @property + def properties(self) -> frozenset[PsSymbolProperty]: + return self._properties + @property def is_field_parameter(self) -> bool: warn( "`is_field_parameter` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldParameter)` instead.", + "Check `param.fields` for emptiness instead.", DeprecationWarning, ) - return isinstance(self, FieldParameter) + return bool(self.fields) @property def is_field_pointer(self) -> bool: warn( "`is_field_pointer` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldPointerParam)` instead.", + "Use `param.get_properties(FieldBasePtr)` instead.", DeprecationWarning, ) - return isinstance(self, FieldPointerParam) + return bool(self.get_properties(FieldBasePtr)) @property def is_field_stride(self) -> bool: warn( "`is_field_stride` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldStrideParam)` instead.", + "Use `param.get_properties(FieldStride)` instead.", DeprecationWarning, ) - return isinstance(self, FieldStrideParam) + return bool(self.get_properties(FieldStride)) @property def is_field_shape(self) -> bool: warn( "`is_field_shape` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldShapeParam)` instead.", - DeprecationWarning, - ) - return isinstance(self, FieldShapeParam) - - -class FieldParameter(KernelParameter, ABC): - __match_args__ = KernelParameter.__match_args__ + ("field",) - - def __init__(self, name: str, dtype: PsType, field: Field): - super().__init__(name, dtype) - self._field = field - - @property - def field(self): - return self._field - - @property - def fields(self): - warn( - "`fields` is deprecated and will be removed in a future version of pystencils. " - "In pystencils >= 2.0, field parameters are only associated with a single field." - "Use the `field` property instead.", - DeprecationWarning, - ) - return [self._field] - - @property - def field_name(self) -> str: - warn( - "`field_name` is deprecated and will be removed in a future version of pystencils. " - "Use `field.name` instead.", + "Use `param.get_properties(FieldShape)` instead.", DeprecationWarning, ) - return self._field.name - - def _hashable_contents(self): - return super()._hashable_contents() + (self._field,) - - -class FieldShapeParam(FieldParameter): - __match_args__ = FieldParameter.__match_args__ + ("coordinate",) - - def __init__(self, name: str, dtype: PsType, field: Field, coordinate: int): - super().__init__(name, dtype, field) - self._coordinate = coordinate - - @property - def coordinate(self): - return self._coordinate - - def _hashable_contents(self): - return super()._hashable_contents() + (self._coordinate,) - - -class FieldStrideParam(FieldParameter): - __match_args__ = FieldParameter.__match_args__ + ("coordinate",) - - def __init__(self, name: str, dtype: PsType, field: Field, coordinate: int): - super().__init__(name, dtype, field) - self._coordinate = coordinate - - @property - def coordinate(self): - return self._coordinate - - def _hashable_contents(self): - return super()._hashable_contents() + (self._coordinate,) - - -class FieldPointerParam(FieldParameter): - def __init__(self, name: str, dtype: PsType, field: Field): - super().__init__(name, dtype, field) + return bool(self.get_properties(FieldShape)) class KernelFunction: @@ -235,7 +195,7 @@ class KernelFunction: return self.parameters def get_fields(self) -> set[Field]: - return set(p.field for p in self._params if isinstance(p, FieldParameter)) + return set.union(*(p.fields for p in self._params)) @property def fields_accessed(self) -> set[Field]: @@ -332,19 +292,19 @@ def create_gpu_kernel_function( def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]): params: list[KernelParameter] = [] + + from pystencils.backend.memory import BufferBasePtr + for symb in symbols: - match symb: - case PsArrayShapeSymbol(name, _, arr, coord): - field = ctx.find_field(arr.name) - params.append(FieldShapeParam(name, symb.get_dtype(), field, coord)) - case PsArrayStrideSymbol(name, _, arr, coord): - field = ctx.find_field(arr.name) - params.append(FieldStrideParam(name, symb.get_dtype(), field, coord)) - case PsArrayBasePointer(name, _, arr): - field = ctx.find_field(arr.name) - params.append(FieldPointerParam(name, symb.get_dtype(), field)) - case PsSymbol(name, _): - params.append(KernelParameter(name, symb.get_dtype())) + props: set[PsSymbolProperty] = set() + for prop in symb.properties: + match prop: + case FieldShape() | FieldStride(): + props.add(prop) + case BufferBasePtr(buf): + field = ctx.find_field(buf.name) + props.add(FieldBasePtr(field)) + params.append(KernelParameter(symb.name, symb.get_dtype(), props)) params.sort(key=lambda p: p.name) return params diff --git a/src/pystencils/backend/memory.py b/src/pystencils/backend/memory.py index f45635150848c9ba97ee5c0b42df74d7e497af08..6594cafbdb494c94fa976e66f5a71b33959a81f7 100644 --- a/src/pystencils/backend/memory.py +++ b/src/pystencils/backend/memory.py @@ -5,41 +5,8 @@ from dataclasses import dataclass from ..types import PsType, PsTypeError, deconstify, PsIntegerType from .exceptions import PsInternalCompilerError -from ..field import Field from .constants import PsConstant - - -@dataclass(frozen=True) -class PsSymbolProperty: - """Base class for symbol properties, which can be used to add additional information to symbols""" - - _unique: ClassVar[bool] = False - """Set to `True` in a subclass for property types of which only one instance per symbol is allowed.""" - - -@dataclass(frozen=True) -class FieldShape(PsSymbolProperty): - """Symbol acts as a shape parameter to a field.""" - - field: Field - coordinate: int - - -@dataclass(frozen=True) -class FieldStride(PsSymbolProperty): - """Symbol acts as a stride parameter to a field.""" - - field: Field - coordinate: int - - -@dataclass(frozen=True) -class FieldBasePtr(PsSymbolProperty): - """Symbol acts as a base pointer to a field.""" - - field: Field - - _unique: ClassVar[bool] = True +from .properties import PsSymbolProperty, UniqueSymbolProperty class PsSymbol: @@ -101,7 +68,9 @@ class PsSymbol: def add_property(self, property: PsSymbolProperty): """Attach a property to this symbol""" - if property._unique and not self.get_properties(type(property)) <= {property}: + if isinstance(property, UniqueSymbolProperty) and not self.get_properties( + type(property) + ) <= {property}: raise ValueError( f"Cannot add second instance of unique property {type(property)} to symbol {self._name}." ) @@ -194,7 +163,7 @@ class PsBuffer: @property def strides(self) -> tuple[PsSymbol | PsConstant, ...]: return self._strides - + @property def dim(self) -> int: return len(self._shape) diff --git a/src/pystencils/backend/properties.py b/src/pystencils/backend/properties.py new file mode 100644 index 0000000000000000000000000000000000000000..d377fb3d35d99b59c4f364cc4d066b736bfd9140 --- /dev/null +++ b/src/pystencils/backend/properties.py @@ -0,0 +1,41 @@ +from __future__ import annotations +from dataclasses import dataclass + +from ..field import Field + + +@dataclass(frozen=True) +class PsSymbolProperty: + """Base class for symbol properties, which can be used to add additional information to symbols""" + + +@dataclass(frozen=True) +class UniqueSymbolProperty(PsSymbolProperty): + """Base class for unique properties, of which only one instance may be registered at a time.""" + + +@dataclass(frozen=True) +class FieldShape(PsSymbolProperty): + """Symbol acts as a shape parameter to a field.""" + + field: Field + coordinate: int + + +@dataclass(frozen=True) +class FieldStride(PsSymbolProperty): + """Symbol acts as a stride parameter to a field.""" + + field: Field + coordinate: int + + +@dataclass(frozen=True) +class FieldBasePtr(UniqueSymbolProperty): + """Symbol acts as a base pointer to a field.""" + + field: Field + + +FieldProperty = FieldShape | FieldStride | FieldBasePtr +_FieldProperty = (FieldShape, FieldStride, FieldBasePtr)