# parameters.py
# Authored by Richard Angersbach
from __future__ import annotations
from warnings import warn
from typing import Sequence, Iterable, Optional
from .properties import (
PsSymbolProperty,
_FieldProperty,
FieldShape,
FieldStride,
FieldBasePtr, ReductionPointerVariable,
)
from ..types import PsType
from ..field import Field
from ..sympyextensions import TypedSymbol
class Parameter:
    """Parameter to an output object of the code generator.

    A parameter is described by its ``name``, its data type ``dtype`` and a set
    of symbol ``properties``. Two views are derived from the properties once,
    at construction time:

    - `fields`: the distinct fields referenced by any `_FieldProperty`,
      sorted by field name;
    - `reduction_pointer`: a `ReductionPointerVariable` property, if one is
      attached.

    Instances compare equal (and hash equally) when they have exactly the same
    type and the same name, dtype and properties.
    """

    __match_args__ = ("name", "dtype", "properties")

    def __init__(
        self, name: str, dtype: PsType, properties: Iterable[PsSymbolProperty] = ()
    ):
        """Create a parameter.

        Args:
            name: Name of the parameter.
            dtype: Data type of the parameter.
            properties: Symbol properties attached to the parameter.
                ``None`` is accepted and treated like an empty iterable.
        """
        self._name = name
        self._dtype = dtype
        self._properties: frozenset[PsSymbolProperty] = (
            frozenset(properties) if properties is not None else frozenset()
        )

        # Several properties may refer to the same field: deduplicate them, then
        # sort by name so the order does not depend on frozenset iteration order.
        self._fields: tuple[Field, ...] = tuple(
            sorted(
                {p.field for p in self._properties if isinstance(p, _FieldProperty)},
                key=lambda f: f.name,
            )
        )

        # NOTE(review): if several reduction-pointer properties are attached,
        # which one is picked here is unspecified (frozenset order); presumably
        # there is at most one — TODO confirm.
        self._reduction_ptr: Optional[ReductionPointerVariable] = next(
            (p for p in self._properties if isinstance(p, ReductionPointerVariable)),
            None,
        )

    @property
    def name(self):
        """Name of the parameter."""
        return self._name

    @property
    def dtype(self):
        """Data type of the parameter."""
        return self._dtype

    def _hashable_contents(self):
        # Everything that defines a parameter's identity for __eq__ / __hash__.
        return (self._name, self._dtype, self._properties)

    # Defining __eq__ alone would set __hash__ to None; this explicit __hash__,
    # consistent with __eq__, keeps parameters usable in sets and as dict keys.
    def __hash__(self) -> int:
        return hash(self._hashable_contents())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Parameter):
            return False
        # Require the exact same type so that different Parameter subclasses
        # with identical contents do not compare equal.
        return (
            type(self) is type(other)
            and self._hashable_contents() == other._hashable_contents()
        )

    def __str__(self) -> str:
        return self._name

    def __repr__(self) -> str:
        return f"{type(self).__name__}(name = {self._name}, dtype = {self._dtype})"

    @property
    def symbol(self) -> TypedSymbol:
        """A `TypedSymbol` built from this parameter's name and dtype."""
        return TypedSymbol(self.name, self.dtype)

    @property
    def fields(self) -> Sequence[Field]:
        """Distinct fields associated with this parameter, sorted by field name."""
        return self._fields

    @property
    def reduction_pointer(self) -> Optional[ReductionPointerVariable]:
        """Reduction pointer property attached to this parameter, or ``None``."""
        return self._reduction_ptr

    @property
    def is_reduction_pointer(self) -> bool:
        """Whether a reduction pointer property is attached to this parameter."""
        return self._reduction_ptr is not None

    def get_properties(
        self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
    ) -> set[PsSymbolProperty]:
        """Retrieve all properties of the given type(s) attached to this parameter.

        Args:
            prop_type: A property class, or a tuple of classes, as accepted by
                `isinstance`.

        Returns:
            A new set containing every matching property (possibly empty).
        """
        return {p for p in self._properties if isinstance(p, prop_type)}

    @property
    def properties(self) -> frozenset[PsSymbolProperty]:
        """All symbol properties attached to this parameter."""
        return self._properties

    @property
    def is_field_parameter(self) -> bool:
        """Whether at least one field is associated with this parameter."""
        return bool(self._fields)

    # Deprecated legacy properties
    # These are kept mostly for the legacy waLBerla code generation system.
    # ``stacklevel=2`` attributes each DeprecationWarning to the caller's line
    # rather than to this module; otherwise the default warning filters hide it.

    @property
    def is_field_pointer(self) -> bool:  # pragma: no cover
        warn(
            "`is_field_pointer` is deprecated and will be removed in a future version of pystencils. "
            "Use `param.get_properties(FieldBasePtr)` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return bool(self.get_properties(FieldBasePtr))

    @property
    def is_field_stride(self) -> bool:  # pragma: no cover
        warn(
            "`is_field_stride` is deprecated and will be removed in a future version of pystencils. "
            "Use `param.get_properties(FieldStride)` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return bool(self.get_properties(FieldStride))

    @property
    def is_field_shape(self) -> bool:  # pragma: no cover
        warn(
            "`is_field_shape` is deprecated and will be removed in a future version of pystencils. "
            "Use `param.get_properties(FieldShape)` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return bool(self.get_properties(FieldShape))

    @property
    def field_name(self) -> str:  # pragma: no cover
        """Name of the first associated field (raises IndexError if there is none)."""
        warn(
            "`field_name` is deprecated and will be removed in a future version of pystencils. "
            "Use `param.fields[0].name` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._fields[0].name