Skip to content
Snippets Groups Projects
Commit a484bce1 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Adapt KernelParameter API for backward-compatibility

parent ce3fb3e8
No related branches found
No related tags found
1 merge request!421Refactor Field Modelling
Pipeline #69786 passed
...@@ -281,7 +281,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ ...@@ -281,7 +281,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
def extract_array_assoc_var(self, param: KernelParameter) -> str: def extract_array_assoc_var(self, param: KernelParameter) -> str:
if param not in self._array_assoc_var_extractions: if param not in self._array_assoc_var_extractions:
field = param.fields.pop() field = param.fields[0]
buffer = self.extract_field(field) buffer = self.extract_field(field)
code: str | None = None code: str | None = None
......
...@@ -97,7 +97,7 @@ class CupyKernelWrapper(KernelWrapper): ...@@ -97,7 +97,7 @@ class CupyKernelWrapper(KernelWrapper):
index_shapes = set() index_shapes = set()
def check_shape(field_ptr: KernelParameter, arr: cp.ndarray): def check_shape(field_ptr: KernelParameter, arr: cp.ndarray):
field = field_ptr.fields.pop() field = field_ptr.fields[0]
if field.has_fixed_shape: if field.has_fixed_shape:
expected_shape = tuple(int(s) for s in field.shape) expected_shape = tuple(int(s) for s in field.shape)
......
...@@ -2,6 +2,7 @@ from __future__ import annotations ...@@ -2,6 +2,7 @@ from __future__ import annotations
from warnings import warn from warnings import warn
from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING
from itertools import chain
from .._deprecation import _deprecated from .._deprecation import _deprecated
...@@ -42,6 +43,17 @@ class KernelParameter: ...@@ -42,6 +43,17 @@ class KernelParameter:
self._properties: frozenset[PsSymbolProperty] = ( self._properties: frozenset[PsSymbolProperty] = (
frozenset(properties) if properties is not None else frozenset() frozenset(properties) if properties is not None else frozenset()
) )
self._fields: tuple[Field, ...] = tuple(
sorted(
set(
p.field # type: ignore
for p in filter(
lambda p: isinstance(p, _FieldProperty), self._properties
)
),
key=lambda f: f.name
)
)
@property @property
def name(self): def name(self):
...@@ -78,23 +90,26 @@ class KernelParameter: ...@@ -78,23 +90,26 @@ class KernelParameter:
return TypedSymbol(self.name, self.dtype) return TypedSymbol(self.name, self.dtype)
@property @property
def fields(self) -> set[Field]: def fields(self) -> tuple[Field, ...]:
"""Set of fields associated with this parameter.""" """Set of fields associated with this parameter."""
return set(p.field for p in filter(lambda p: isinstance(p, _FieldProperty), self.properties)) # type: ignore return self._fields
def get_properties( def get_properties(
self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...] self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
) -> set[PsSymbolProperty]: ) -> set[PsSymbolProperty]:
"""Retrieve all properties of the given type(s) attached to this parameter""" """Retrieve all properties of the given type(s) attached to this parameter"""
return set(filter(lambda p: isinstance(p, prop_type), self._properties)) return set(filter(lambda p: isinstance(p, prop_type), self._properties))
@property @property
def properties(self) -> frozenset[PsSymbolProperty]: def properties(self) -> frozenset[PsSymbolProperty]:
return self._properties return self._properties
@property @property
def is_field_parameter(self) -> bool: def is_field_parameter(self) -> bool:
return bool(self.fields) return bool(self._fields)
# Deprecated legacy properties
# These are kept mostly for the legacy waLBerla code generation system
@property @property
def is_field_pointer(self) -> bool: def is_field_pointer(self) -> bool:
...@@ -123,6 +138,15 @@ class KernelParameter: ...@@ -123,6 +138,15 @@ class KernelParameter:
) )
return bool(self.get_properties(FieldShape)) return bool(self.get_properties(FieldShape))
@property
def field_name(self) -> str:
warn(
"`field_name` is deprecated and will be removed in a future version of pystencils. "
"Use `param.fields[0].name` instead.",
DeprecationWarning,
)
return self._fields[0].name
class KernelFunction: class KernelFunction:
"""A pystencils kernel function. """A pystencils kernel function.
...@@ -190,7 +214,7 @@ class KernelFunction: ...@@ -190,7 +214,7 @@ class KernelFunction:
return self.parameters return self.parameters
def get_fields(self) -> set[Field]: def get_fields(self) -> set[Field]:
return set.union(*(p.fields for p in self._params)) return set(chain.from_iterable(p.fields for p in self._params))
@property @property
def fields_accessed(self) -> set[Field]: def fields_accessed(self) -> set[Field]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment