From b669db1d8ca5f08550ee7bbfd255f408454536f0 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sat, 16 Dec 2023 16:28:50 +0100
Subject: [PATCH] fixed a bug in kernel namespaces. extended map_param to
 multiple RHSs.

---
 src/pystencilssfg/composer/basic_composer.py     | 16 +++++++++++-----
 src/pystencilssfg/source_components.py           |  3 ++-
 .../source_concepts/source_objects.py            |  2 +-
 3 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py
index 492847a..56583f0 100644
--- a/src/pystencilssfg/composer/basic_composer.py
+++ b/src/pystencilssfg/composer/basic_composer.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Sequence
 from abc import ABC, abstractmethod
 import numpy as np
 
-from pystencils import Field
+from pystencils import Field, TypedSymbol
 from pystencils.astnodes import KernelFunction
 
 from ..tree import (
@@ -74,7 +74,7 @@ class SfgBasicComposer:
         """Returns the kernel namespace of the given name, creating it if it does not exist yet."""
         kns = self._ctx.get_kernel_namespace(name)
         if kns is None:
-            kns = SfgKernelNamespace(self, name)
+            kns = SfgKernelNamespace(self._ctx, name)
             self._ctx.add_kernel_namespace(kns)
 
         return kns
@@ -198,11 +198,17 @@ class SfgBasicComposer:
         return SfgDeferredFieldMapping(field, src_object)
 
     def map_param(
-        self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str
+        self,
+        lhs: TypedSymbolOrObject,
+        rhs: TypedSymbolOrObject | Sequence[TypedSymbolOrObject],
+        mapping: str,
     ):
         """Arbitrary parameter mapping: Add a single line of code to define a left-hand
-        side object from a right-hand side."""
-        return SfgStatements(mapping, (lhs,), (rhs,))
+        side object from one or multiple right-hand side dependencies."""
+        if isinstance(rhs, (TypedSymbol, SrcObject)):
+            return SfgStatements(mapping, (lhs,), (rhs,))
+        else:
+            return SfgStatements(mapping, (lhs,), rhs)
 
     def map_vector(self, lhs_components: Sequence[TypedSymbolOrObject], rhs: SrcVector):
         """Extracts scalar numerical values from a vector data type.
diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py
index bf3a514..0b3d8fa 100644
--- a/src/pystencilssfg/source_components.py
+++ b/src/pystencilssfg/source_components.py
@@ -15,6 +15,7 @@ from .exceptions import SfgException
 
 if TYPE_CHECKING:
     from .tree import SfgCallTreeNode
+    from .context import SfgContext
 
 
 class SfgEmptyLines:
@@ -124,7 +125,7 @@ class SfgKernelNamespace:
 class SfgKernelHandle:
     def __init__(
         self,
-        ctx,
+        ctx: SfgContext,
         name: str,
         namespace: SfgKernelNamespace,
         parameters: Sequence[KernelFunction.Parameter],
diff --git a/src/pystencilssfg/source_concepts/source_objects.py b/src/pystencilssfg/source_concepts/source_objects.py
index 01e98d5..584d20c 100644
--- a/src/pystencilssfg/source_concepts/source_objects.py
+++ b/src/pystencilssfg/source_concepts/source_objects.py
@@ -54,7 +54,7 @@ class SrcObject:
         return self.name
 
 
-TypedSymbolOrObject: TypeAlias = Union[TypedSymbol, SrcObject]
+TypedSymbolOrObject: TypeAlias = TypedSymbol | SrcObject
 
 
 class SrcField(SrcObject, ABC):
-- 
GitLab