Skip to content
Snippets Groups Projects
Commit 841e40eb authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Added vector extraction to composer

parent 87846725
Branches
Tags
No related merge requests found
Pipeline #57673 passed
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional, Sequence
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pystencils import Field from pystencils import Field
...@@ -9,7 +9,7 @@ from .tree import SfgCallTreeNode, SfgKernelCallNode, SfgStatements, SfgSequence ...@@ -9,7 +9,7 @@ from .tree import SfgCallTreeNode, SfgKernelCallNode, SfgStatements, SfgSequence
from .tree.deferred_nodes import SfgDeferredFieldMapping from .tree.deferred_nodes import SfgDeferredFieldMapping
from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch from .tree.conditional import SfgCondition, SfgCustomCondition, SfgBranch
from .source_components import SfgFunction, SfgHeaderInclude, SfgKernelNamespace, SfgKernelHandle from .source_components import SfgFunction, SfgHeaderInclude, SfgKernelNamespace, SfgKernelHandle
from .source_concepts import SrcField, TypedSymbolOrObject from .source_concepts import SrcField, TypedSymbolOrObject, SrcVector
if TYPE_CHECKING: if TYPE_CHECKING:
from .context import SfgContext from .context import SfgContext
...@@ -92,6 +92,11 @@ class SfgComposer: ...@@ -92,6 +92,11 @@ class SfgComposer:
def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str): def map_param(self, lhs: TypedSymbolOrObject, rhs: TypedSymbolOrObject, mapping: str):
return SfgStatements(mapping, (lhs,), (rhs,)) return SfgStatements(mapping, (lhs,), (rhs,))
def map_vector(self, lhs_components: Sequence[TypedSymbolOrObject], rhs: SrcVector):
return make_sequence(*(
rhs.extract_component(dest, coord) for coord, dest in enumerate(lhs_components)
))
class SfgNodeBuilder(ABC): class SfgNodeBuilder(ABC):
@abstractmethod @abstractmethod
......
...@@ -75,7 +75,7 @@ class SfgKernelNamespace: ...@@ -75,7 +75,7 @@ class SfgKernelNamespace:
def create(self, assignments, name: str | None = None, config: CreateKernelConfig | None = None): def create(self, assignments, name: str | None = None, config: CreateKernelConfig | None = None):
if config is None: if config is None:
config = CreateKernelConfig() config = CreateKernelConfig()
if name is not None: if name is not None:
if name in self._asts: if name in self._asts:
raise ValueError(f"Duplicate ASTs: An AST with name {name} already exists in namespace {self._name}") raise ValueError(f"Duplicate ASTs: An AST with name {name} already exists in namespace {self._name}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment