From a484bce1a48dbf6aa4765f55dae3c00df847a892 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 23 Oct 2024 10:02:54 +0200 Subject: [PATCH] Adapt KernelParameter API for backward-compatibility --- .../backend/jit/cpu_extension_module.py | 2 +- src/pystencils/backend/jit/gpu_cupy.py | 2 +- src/pystencils/backend/kernelfunction.py | 34 ++++++++++++++++--- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/pystencils/backend/jit/cpu_extension_module.py b/src/pystencils/backend/jit/cpu_extension_module.py index dede60cba..d7f644550 100644 --- a/src/pystencils/backend/jit/cpu_extension_module.py +++ b/src/pystencils/backend/jit/cpu_extension_module.py @@ -281,7 +281,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ def extract_array_assoc_var(self, param: KernelParameter) -> str: if param not in self._array_assoc_var_extractions: - field = param.fields.pop() + field = param.fields[0] buffer = self.extract_field(field) code: str | None = None diff --git a/src/pystencils/backend/jit/gpu_cupy.py b/src/pystencils/backend/jit/gpu_cupy.py index 15f5f6967..7f38d9d43 100644 --- a/src/pystencils/backend/jit/gpu_cupy.py +++ b/src/pystencils/backend/jit/gpu_cupy.py @@ -97,7 +97,7 @@ class CupyKernelWrapper(KernelWrapper): index_shapes = set() def check_shape(field_ptr: KernelParameter, arr: cp.ndarray): - field = field_ptr.fields.pop() + field = field_ptr.fields[0] if field.has_fixed_shape: expected_shape = tuple(int(s) for s in field.shape) diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index da0b59e8f..9275c55ec 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -2,6 +2,7 @@ from __future__ import annotations from warnings import warn from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING +from itertools import chain from .._deprecation import _deprecated @@ -42,6 +43,17 @@ class KernelParameter: self._properties: frozenset[PsSymbolProperty] = ( frozenset(properties) if properties is not None else frozenset() ) + self._fields: tuple[Field, ...] = tuple( + sorted( + set( + p.field # type: ignore + for p in filter( + lambda p: isinstance(p, _FieldProperty), self._properties + ) + ), + key=lambda f: f.name + ) + ) @property def name(self): @@ -78,23 +90,26 @@ class KernelParameter: return TypedSymbol(self.name, self.dtype) @property - def fields(self) -> set[Field]: + def fields(self) -> tuple[Field, ...]: """Set of fields associated with this parameter.""" - return set(p.field for p in filter(lambda p: isinstance(p, _FieldProperty), self.properties)) # type: ignore + return self._fields def get_properties( self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...] ) -> set[PsSymbolProperty]: """Retrieve all properties of the given type(s) attached to this parameter""" return set(filter(lambda p: isinstance(p, prop_type), self._properties)) - + @property def properties(self) -> frozenset[PsSymbolProperty]: return self._properties @property def is_field_parameter(self) -> bool: - return bool(self.fields) + return bool(self._fields) + + # Deprecated legacy properties + # These are kept mostly for the legacy waLBerla code generation system @property def is_field_pointer(self) -> bool: @@ -123,6 +138,15 @@ class KernelParameter: ) return bool(self.get_properties(FieldShape)) + @property + def field_name(self) -> str: + warn( + "`field_name` is deprecated and will be removed in a future version of pystencils. " + "Use `param.fields[0].name` instead.", + DeprecationWarning, + ) + return self._fields[0].name + class KernelFunction: """A pystencils kernel function. @@ -190,7 +214,7 @@ class KernelFunction: return self.parameters def get_fields(self) -> set[Field]: - return set.union(*(p.fields for p in self._params)) + return set(chain.from_iterable(p.fields for p in self._params)) @property def fields_accessed(self) -> set[Field]: -- GitLab