Skip to content
Snippets Groups Projects

Reduction Support

Open Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Viewing commit 4e748308
Show latest version
2 files
+ 26
7
Preferences
Compare changes
Files
2
from __future__ import annotations
from warnings import warn
from typing import Sequence, Iterable
from typing import Sequence, Iterable, Optional
from .properties import (
PsSymbolProperty,
_FieldProperty,
FieldShape,
FieldStride,
FieldBasePtr,
FieldBasePtr, ReductionPointerVariable,
)
from ..types import PsType
from ..field import Field
@@ -39,6 +39,9 @@ class Parameter:
key=lambda f: f.name,
)
)
self._reduction_ptr: Optional[ReductionPointerVariable] = next(
(e for e in self._properties if isinstance(e, ReductionPointerVariable)), None
)
@property
def name(self):
@@ -79,6 +82,11 @@ class Parameter:
"""Set of fields associated with this parameter."""
return self._fields
@property
def reduction_pointer(self) -> Optional[ReductionPointerVariable]:
"""Reduction pointer associated with this parameter."""
return self._reduction_ptr
def get_properties(
self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
) -> set[PsSymbolProperty]:
@@ -105,6 +113,10 @@ class Parameter:
)
return bool(self.get_properties(FieldBasePtr))
@property
def is_reduction_pointer(self) -> bool:
return bool(self._reduction_ptr)
@property
def is_field_stride(self) -> bool: # pragma: no cover
warn(