Skip to content
Snippets Groups Projects
Commit b669db1d authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fixed a bug in kernel namespaces. extended map_param to multiple RHSs.

parent 1934248a
Branches
Tags
No related merge requests found
......@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Sequence
from abc import ABC, abstractmethod
import numpy as np
from pystencils import Field
from pystencils import Field, TypedSymbol
from pystencils.astnodes import KernelFunction
from ..tree import (
......@@ -74,7 +74,7 @@ class SfgBasicComposer:
"""Returns the kernel namespace of the given name, creating it if it does not exist yet."""
kns = self._ctx.get_kernel_namespace(name)
if kns is None:
kns = SfgKernelNamespace(self, name)
kns = SfgKernelNamespace(self._ctx, name)
self._ctx.add_kernel_namespace(kns)
return kns
......@@ -198,11 +198,17 @@ class SfgBasicComposer:
return SfgDeferredFieldMapping(field, src_object)
def map_param(
self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str
self,
lhs: TypedSymbolOrObject,
rhs: TypedSymbolOrObject | Sequence[TypedSymbolOrObject],
mapping: str,
):
"""Arbitrary parameter mapping: Add a single line of code to define a left-hand
side object from a right-hand side."""
return SfgStatements(mapping, (lhs,), (rhs,))
side object from one or multiple right-hand side dependencies."""
if isinstance(rhs, (TypedSymbol, SrcObject)):
return SfgStatements(mapping, (lhs,), (rhs,))
else:
return SfgStatements(mapping, (lhs,), rhs)
def map_vector(self, lhs_components: Sequence[TypedSymbolOrObject], rhs: SrcVector):
"""Extracts scalar numerical values from a vector data type.
......
......@@ -15,6 +15,7 @@ from .exceptions import SfgException
if TYPE_CHECKING:
from .tree import SfgCallTreeNode
from .context import SfgContext
class SfgEmptyLines:
......@@ -124,7 +125,7 @@ class SfgKernelNamespace:
class SfgKernelHandle:
def __init__(
self,
ctx,
ctx: SfgContext,
name: str,
namespace: SfgKernelNamespace,
parameters: Sequence[KernelFunction.Parameter],
......
......@@ -54,7 +54,7 @@ class SrcObject:
return self.name
TypedSymbolOrObject: TypeAlias = Union[TypedSymbol, SrcObject]
TypedSymbolOrObject: TypeAlias = TypedSymbol | SrcObject
class SrcField(SrcObject, ABC):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment