From b669db1d8ca5f08550ee7bbfd255f408454536f0 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Sat, 16 Dec 2023 16:28:50 +0100 Subject: [PATCH] fixed a bug in kernel namespaces. extended map_param to multiple RHSs. --- src/pystencilssfg/composer/basic_composer.py | 16 +++++++++++----- src/pystencilssfg/source_components.py | 3 ++- .../source_concepts/source_objects.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 492847a..56583f0 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Sequence from abc import ABC, abstractmethod import numpy as np -from pystencils import Field +from pystencils import Field, TypedSymbol from pystencils.astnodes import KernelFunction from ..tree import ( @@ -74,7 +74,7 @@ class SfgBasicComposer: """Returns the kernel namespace of the given name, creating it if it does not exist yet.""" kns = self._ctx.get_kernel_namespace(name) if kns is None: - kns = SfgKernelNamespace(self, name) + kns = SfgKernelNamespace(self._ctx, name) self._ctx.add_kernel_namespace(kns) return kns @@ -198,11 +198,17 @@ class SfgBasicComposer: return SfgDeferredFieldMapping(field, src_object) def map_param( - self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str + self, + lhs: TypedSymbolOrObject, + rhs: TypedSymbolOrObject | Sequence[TypedSymbolOrObject], + mapping: str, ): """Arbitrary parameter mapping: Add a single line of code to define a left-hand - side object from a right-hand side.""" - return SfgStatements(mapping, (lhs,), (rhs,)) + side object from one or multiple right-hand side dependencies.""" + if isinstance(rhs, (TypedSymbol, SrcObject)): + return SfgStatements(mapping, (lhs,), (rhs,)) + else: + return SfgStatements(mapping, (lhs,), rhs) def map_vector(self, lhs_components: Sequence[TypedSymbolOrObject], rhs: SrcVector): """Extracts scalar numerical values from a vector data type. diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index bf3a514..0b3d8fa 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -15,6 +15,7 @@ from .exceptions import SfgException if TYPE_CHECKING: from .tree import SfgCallTreeNode + from .context import SfgContext class SfgEmptyLines: @@ -124,7 +125,7 @@ class SfgKernelNamespace: class SfgKernelHandle: def __init__( self, - ctx, + ctx: SfgContext, name: str, namespace: SfgKernelNamespace, parameters: Sequence[KernelFunction.Parameter], diff --git a/src/pystencilssfg/source_concepts/source_objects.py b/src/pystencilssfg/source_concepts/source_objects.py index 01e98d5..584d20c 100644 --- a/src/pystencilssfg/source_concepts/source_objects.py +++ b/src/pystencilssfg/source_concepts/source_objects.py @@ -54,7 +54,7 @@ class SrcObject: return self.name -TypedSymbolOrObject: TypeAlias = Union[TypedSymbol, SrcObject] +TypedSymbolOrObject: TypeAlias = TypedSymbol | SrcObject class SrcField(SrcObject, ABC): -- GitLab