From a484bce1a48dbf6aa4765f55dae3c00df847a892 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 23 Oct 2024 10:02:54 +0200
Subject: [PATCH] Adapt KernelParameter API for backward-compatibility

---
 .../backend/jit/cpu_extension_module.py       |  2 +-
 src/pystencils/backend/jit/gpu_cupy.py        |  2 +-
 src/pystencils/backend/kernelfunction.py      | 34 ++++++++++++++++---
 3 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/src/pystencils/backend/jit/cpu_extension_module.py b/src/pystencils/backend/jit/cpu_extension_module.py
index dede60cba..d7f644550 100644
--- a/src/pystencils/backend/jit/cpu_extension_module.py
+++ b/src/pystencils/backend/jit/cpu_extension_module.py
@@ -281,7 +281,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
     def extract_array_assoc_var(self, param: KernelParameter) -> str:
         if param not in self._array_assoc_var_extractions:
-            field = param.fields.pop()
+            field = param.fields[0]
             buffer = self.extract_field(field)
             code: str | None = None
 
diff --git a/src/pystencils/backend/jit/gpu_cupy.py b/src/pystencils/backend/jit/gpu_cupy.py
index 15f5f6967..7f38d9d43 100644
--- a/src/pystencils/backend/jit/gpu_cupy.py
+++ b/src/pystencils/backend/jit/gpu_cupy.py
@@ -97,7 +97,7 @@ class CupyKernelWrapper(KernelWrapper):
         index_shapes = set()
 
         def check_shape(field_ptr: KernelParameter, arr: cp.ndarray):
-            field = field_ptr.fields.pop()
+            field = field_ptr.fields[0]
 
             if field.has_fixed_shape:
                 expected_shape = tuple(int(s) for s in field.shape)
diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py
index da0b59e8f..9275c55ec 100644
--- a/src/pystencils/backend/kernelfunction.py
+++ b/src/pystencils/backend/kernelfunction.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from warnings import warn
 from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING
+from itertools import chain
 
 from .._deprecation import _deprecated
 
@@ -42,6 +43,17 @@ class KernelParameter:
         self._properties: frozenset[PsSymbolProperty] = (
             frozenset(properties) if properties is not None else frozenset()
         )
+        self._fields: tuple[Field, ...] = tuple(
+            sorted(
+                set(
+                    p.field  # type: ignore
+                    for p in filter(
+                        lambda p: isinstance(p, _FieldProperty), self._properties
+                    )
+                ),
+                key=lambda f: f.name
+            )
+        )
 
     @property
     def name(self):
@@ -78,23 +90,26 @@ class KernelParameter:
         return TypedSymbol(self.name, self.dtype)
 
     @property
-    def fields(self) -> set[Field]:
+    def fields(self) -> tuple[Field, ...]:
         """Set of fields associated with this parameter."""
-        return set(p.field for p in filter(lambda p: isinstance(p, _FieldProperty), self.properties))  # type: ignore
+        return self._fields
 
     def get_properties(
         self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
     ) -> set[PsSymbolProperty]:
         """Retrieve all properties of the given type(s) attached to this parameter"""
         return set(filter(lambda p: isinstance(p, prop_type), self._properties))
-    
+
     @property
     def properties(self) -> frozenset[PsSymbolProperty]:
         return self._properties
 
     @property
     def is_field_parameter(self) -> bool:
-        return bool(self.fields)
+        return bool(self._fields)
+
+    #   Deprecated legacy properties
+    #   These are kept mostly for the legacy waLBerla code generation system
 
     @property
     def is_field_pointer(self) -> bool:
@@ -123,6 +138,15 @@ class KernelParameter:
         )
         return bool(self.get_properties(FieldShape))
 
+    @property
+    def field_name(self) -> str:
+        warn(
+            "`field_name` is deprecated and will be removed in a future version of pystencils. "
+            "Use `param.fields[0].name` instead.",
+            DeprecationWarning,
+        )
+        return self._fields[0].name
+
 
 class KernelFunction:
     """A pystencils kernel function.
@@ -190,7 +214,7 @@ class KernelFunction:
         return self.parameters
 
     def get_fields(self) -> set[Field]:
-        return set.union(*(p.fields for p in self._params))
+        return set(chain.from_iterable(p.fields for p in self._params))
 
     @property
     def fields_accessed(self) -> set[Field]:
-- 
GitLab