Skip to content
Snippets Groups Projects
Commit a484bce1 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Adapt KernelParameter API for backward-compatibility

parent ce3fb3e8
No related branches found
No related tags found
1 merge request!421Refactor Field Modelling
Pipeline #69786 passed
......@@ -281,7 +281,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
def extract_array_assoc_var(self, param: KernelParameter) -> str:
if param not in self._array_assoc_var_extractions:
field = param.fields.pop()
field = param.fields[0]
buffer = self.extract_field(field)
code: str | None = None
......
......@@ -97,7 +97,7 @@ class CupyKernelWrapper(KernelWrapper):
index_shapes = set()
def check_shape(field_ptr: KernelParameter, arr: cp.ndarray):
field = field_ptr.fields.pop()
field = field_ptr.fields[0]
if field.has_fixed_shape:
expected_shape = tuple(int(s) for s in field.shape)
......
......@@ -2,6 +2,7 @@ from __future__ import annotations
from warnings import warn
from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING
from itertools import chain
from .._deprecation import _deprecated
......@@ -42,6 +43,17 @@ class KernelParameter:
self._properties: frozenset[PsSymbolProperty] = (
frozenset(properties) if properties is not None else frozenset()
)
self._fields: tuple[Field, ...] = tuple(
sorted(
set(
p.field # type: ignore
for p in filter(
lambda p: isinstance(p, _FieldProperty), self._properties
)
),
key=lambda f: f.name
)
)
@property
def name(self):
......@@ -78,23 +90,26 @@ class KernelParameter:
return TypedSymbol(self.name, self.dtype)
@property
def fields(self) -> set[Field]:
def fields(self) -> tuple[Field, ...]:
"""Set of fields associated with this parameter."""
return set(p.field for p in filter(lambda p: isinstance(p, _FieldProperty), self.properties)) # type: ignore
return self._fields
def get_properties(
self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
) -> set[PsSymbolProperty]:
"""Retrieve all properties of the given type(s) attached to this parameter"""
return set(filter(lambda p: isinstance(p, prop_type), self._properties))
@property
def properties(self) -> frozenset[PsSymbolProperty]:
return self._properties
@property
def is_field_parameter(self) -> bool:
return bool(self.fields)
return bool(self._fields)
# Deprecated legacy properties
# These are kept mostly for the legacy waLBerla code generation system
@property
def is_field_pointer(self) -> bool:
......@@ -123,6 +138,15 @@ class KernelParameter:
)
return bool(self.get_properties(FieldShape))
@property
def field_name(self) -> str:
warn(
"`field_name` is deprecated and will be removed in a future version of pystencils. "
"Use `param.fields[0].name` instead.",
DeprecationWarning,
)
return self._fields[0].name
class KernelFunction:
"""A pystencils kernel function.
......@@ -190,7 +214,7 @@ class KernelFunction:
return self.parameters
def get_fields(self) -> set[Field]:
return set.union(*(p.fields for p in self._params))
return set(chain.from_iterable(p.fields for p in self._params))
@property
def fields_accessed(self) -> set[Field]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment