diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 492847a0554e493c2ca65c3f79b0fe0e2869b220..56583f063f3aba0623fa4b8b5f84836a93795e08 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Sequence from abc import ABC, abstractmethod import numpy as np -from pystencils import Field +from pystencils import Field, TypedSymbol from pystencils.astnodes import KernelFunction from ..tree import ( @@ -74,7 +74,7 @@ class SfgBasicComposer: """Returns the kernel namespace of the given name, creating it if it does not exist yet.""" kns = self._ctx.get_kernel_namespace(name) if kns is None: - kns = SfgKernelNamespace(self, name) + kns = SfgKernelNamespace(self._ctx, name) self._ctx.add_kernel_namespace(kns) return kns @@ -198,11 +198,17 @@ class SfgBasicComposer: return SfgDeferredFieldMapping(field, src_object) def map_param( - self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str + self, + lhs: TypedSymbolOrObject, + rhs: TypedSymbolOrObject | Sequence[TypedSymbolOrObject], + mapping: str, ): """Arbitrary parameter mapping: Add a single line of code to define a left-hand - side object from a right-hand side.""" - return SfgStatements(mapping, (lhs,), (rhs,)) + side object from one or multiple right-hand side dependencies.""" + if isinstance(rhs, (TypedSymbol, SrcObject)): + return SfgStatements(mapping, (lhs,), (rhs,)) + else: + return SfgStatements(mapping, (lhs,), rhs) def map_vector(self, lhs_components: Sequence[TypedSymbolOrObject], rhs: SrcVector): """Extracts scalar numerical values from a vector data type. diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index bf3a5147b47b3dd62bad99e450c1c89335a9cc49..0b3d8fab1572b9954aa75fe4602f43797e2163a6 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -15,6 +15,7 @@ from .exceptions import SfgException if TYPE_CHECKING: from .tree import SfgCallTreeNode + from .context import SfgContext class SfgEmptyLines: @@ -124,7 +125,7 @@ class SfgKernelNamespace: class SfgKernelHandle: def __init__( self, - ctx, + ctx: SfgContext, name: str, namespace: SfgKernelNamespace, parameters: Sequence[KernelFunction.Parameter], diff --git a/src/pystencilssfg/source_concepts/source_objects.py b/src/pystencilssfg/source_concepts/source_objects.py index 01e98d558b0ef440dcaed1980b0eb367a7a81843..584d20ca700f9e4835be7f240ff31ffd1c029905 100644 --- a/src/pystencilssfg/source_concepts/source_objects.py +++ b/src/pystencilssfg/source_concepts/source_objects.py @@ -54,7 +54,7 @@ class SrcObject: return self.name -TypedSymbolOrObject: TypeAlias = Union[TypedSymbol, SrcObject] +TypedSymbolOrObject: TypeAlias = TypedSymbol | SrcObject class SrcField(SrcObject, ABC):