Skip to content
Snippets Groups Projects
Commit f1c556e6 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Integrate reduction pointers to parameters.py

parent 777ab888
No related branches found
No related tags found
1 merge request!438Reduction Support
from __future__ import annotations
from warnings import warn
from typing import Sequence, Iterable
from typing import Sequence, Iterable, Optional
from .properties import (
PsSymbolProperty,
_FieldProperty,
FieldShape,
FieldStride,
FieldBasePtr,
FieldBasePtr, ReductionPointerVariable,
)
from ..types import PsType
from ..field import Field
......@@ -39,6 +39,9 @@ class Parameter:
key=lambda f: f.name,
)
)
self._reduction_ptr: Optional[ReductionPointerVariable] = next(
(e for e in self._properties if isinstance(e, ReductionPointerVariable)), None
)
@property
def name(self):
......@@ -79,6 +82,11 @@ class Parameter:
"""Set of fields associated with this parameter."""
return self._fields
@property
def reduction_pointer(self) -> Optional[ReductionPointerVariable]:
"""Reduction pointer associated with this parameter."""
return self._reduction_ptr
def get_properties(
self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
) -> set[PsSymbolProperty]:
......@@ -105,6 +113,10 @@ class Parameter:
)
return bool(self.get_properties(FieldBasePtr))
@property
def is_reduction_pointer(self) -> bool:
return bool(self._reduction_ptr)
@property
def is_field_stride(self) -> bool: # pragma: no cover
warn(
......
......@@ -206,6 +206,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
self._array_assoc_var_extractions: dict[Parameter, str] = dict()
self._scalar_extractions: dict[Parameter, str] = dict()
self._reduction_ptrs: dict[Parameter, str] = dict()
self._constraint_checks: list[str] = []
self._call: str | None = None
......@@ -265,10 +267,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return self._array_buffers[field]
def extract_scalar(self, param: Parameter) -> str:
if any(isinstance(e, ReductionPointerVariable) for e in param.properties):
# TODO: implement
pass
elif param not in self._scalar_extractions:
if param not in self._scalar_extractions:
extract_func = self._scalar_extractor(param.dtype)
code = self.TMPL_EXTRACT_SCALAR.format(
name=param.name,
......@@ -279,6 +278,12 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return param.name
def extract_reduction_ptr(self, param: Parameter) -> str:
if param not in self._reduction_ptrs:
# TODO: implement
pass
return param.name
def extract_array_assoc_var(self, param: Parameter) -> str:
if param not in self._array_assoc_var_extractions:
field = param.fields[0]
......@@ -306,7 +311,9 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return param.name
def extract_parameter(self, param: Parameter):
if param.is_field_parameter:
if param.is_reduction_pointer:
self.extract_reduction_ptr(param)
elif param.is_field_parameter:
self.extract_array_assoc_var(param)
else:
self.extract_scalar(param)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment