Skip to content
Snippets Groups Projects
Commit 89b6f68b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

baby steps toward field parameter mapping

parent c13f1095
No related branches found
No related tags found
No related merge requests found
...@@ -32,12 +32,21 @@ class SfgKernelNamespace: ...@@ -32,12 +32,21 @@ class SfgKernelNamespace:
class SfgKernelHandle: class SfgKernelHandle:
def __init__(self, ctx, name: str, namespace: SfgKernelNamespace, parameters): def __init__(self, ctx, name: str, namespace: SfgKernelNamespace, parameters: Sequence[KernelFunction.Parameter]):
self._ctx = ctx self._ctx = ctx
self._name = name self._name = name
self._namespace = namespace self._namespace = namespace
self._parameters = parameters self._parameters = parameters
self._scalar_params = set()
self._fields = set()
for param in self._parameters:
if param.is_field_parameter:
self._fields |= set(param.fields)
else:
self._scalar_params.add(param.symbol)
@property @property
def kernel_name(self): def kernel_name(self):
return self._name return self._name
...@@ -53,4 +62,12 @@ class SfgKernelHandle: ...@@ -53,4 +62,12 @@ class SfgKernelHandle:
@property @property
def parameters(self): def parameters(self):
return self._parameters return self._parameters
@property
def scalar_parameters(self):
return self._scalar_params
@property
def fields(self):
return self.fields
\ No newline at end of file
...@@ -5,16 +5,19 @@ if TYPE_CHECKING: ...@@ -5,16 +5,19 @@ if TYPE_CHECKING:
from .context import SfgContext from .context import SfgContext
from .tree import SfgCallTreeNode, SfgSequence from .tree import SfgCallTreeNode, SfgSequence
from .tree.visitors import ParameterCollector from .tree.visitors import FlattenSequences, ParameterCollector
class SfgFunction: class SfgFunction:
def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode): def __init__(self, ctx: SfgContext, name: str, tree: SfgCallTreeNode):
self._ctx = ctx self._ctx = ctx
self._name = name self._name = name
self._tree = tree self._tree = tree
flattener = FlattenSequences()
flattener.visit(self._tree)
param_collector = ParameterCollector() param_collector = ParameterCollector()
self._parameters = param_collector.visit(tree) self._parameters = param_collector.visit(self._tree)
@property @property
def name(self): def name(self):
......
from abc import ABC, abstractmethod
from .source_concepts import SrcObject, SrcMemberAccess
class SrcContiguousContainer(SrcObject):
def __init__(self, src_type, identifier: Optional[str]):
super().__init__(src_type, identifier)
@abstractmethod
def ptr(self) -> SrcMemberAccess:
pass
@abstractmethod
def size(self, dimension: int) -> SrcMemberAccess:
pass
@abstractmethod
def stride(self, dimension: int) -> SrcMemberAccess:
pass
from typing import Optional
from ..source_concepts import SrcMemberAccess
from ..containers import SrcContiguousContainer
class std_mdspan(SrcContiguousContainer):
def __init__(self, identifer: str):
super().__init__("std::mdspan", identifier)
def ptr(self):
return SrcMemberAccess(self, f"{self._identifier}.data_handle()")
def size(self, dimension: int):
return SrcMemberAccess(self, f"{self._identifier}.extents().extent({dimension})")
def stride(self, dimension: int):
return SrcMemberAccess(self, f"{self._identifier}.stride({dimension})")
from typing import Optional
from abc import ABC, abstractmethod
from pystencils import TypedSymbol
class SrcClass:
def __init__(self):
pass
class SrcObject(ABC):
def __init__(self, src_type, identifier: Optional[str]):
self._src_type = src_type
self._identifier = identifier
@property
def _sfg_symbol(self):
return TypedSymbol(self._identifier, self._src_type)
class SrcMemberAccess():
def __init__(self, obj: SrcObject, code_string: str):
self._obj = obj
self._code_string = code_string
def _sfg_code_string():
return self._code_string
...@@ -41,11 +41,24 @@ class SfgCallTreeLeaf(SfgCallTreeNode, ABC): ...@@ -41,11 +41,24 @@ class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
def required_symbols(self) -> set(TypedSymbol): def required_symbols(self) -> set(TypedSymbol):
pass pass
class SfgParameterDefinition(SfgCallTreeLeaf):
def __init__(self, defined_param: TypedSymbol, required_params: Set[TypedSymbol], code_string: str):
self._defined_param = defined_param
self._required_params = required_params
self._code_string = code_string
@property @property
@abstractmethod def defined_symbol(self) -> TypedSymbol:
def defined_symbols(self) -> set(TypedSymbol): return self._defined_param
pass
@property
def required_symbols(self) -> set(TypedSymbol):
return self._required_params
def get_code(self):
return self._code_string
class SfgCustomStatement(SfgCallTreeLeaf): class SfgCustomStatement(SfgCallTreeLeaf):
def __init__(self, statement: str): def __init__(self, statement: str):
...@@ -54,9 +67,6 @@ class SfgCustomStatement(SfgCallTreeLeaf): ...@@ -54,9 +67,6 @@ class SfgCustomStatement(SfgCallTreeLeaf):
def required_symbols(self) -> set(TypedSymbol): def required_symbols(self) -> set(TypedSymbol):
return set() return set()
def defined_symbols(self) -> set(TypedSymbol):
return set()
def get_code(self, ctx: SfgContext) -> str: def get_code(self, ctx: SfgContext) -> str:
return self._statement return self._statement
...@@ -96,10 +106,6 @@ class SfgKernelCallNode(SfgCallTreeLeaf): ...@@ -96,10 +106,6 @@ class SfgKernelCallNode(SfgCallTreeLeaf):
def required_symbols(self) -> set(TypedSymbol): def required_symbols(self) -> set(TypedSymbol):
return set(p.symbol for p in self._kernel_handle.parameters) return set(p.symbol for p in self._kernel_handle.parameters)
@property
def defined_symbols(self) -> set(TypedSymbol):
return set()
def get_code(self, ctx: SfgContext) -> str: def get_code(self, ctx: SfgContext) -> str:
ast_params = self._kernel_handle.parameters ast_params = self._kernel_handle.parameters
fnc_name = self._kernel_handle.fully_qualified_name fnc_name = self._kernel_handle.fully_qualified_name
......
...@@ -5,13 +5,14 @@ if TYPE_CHECKING: ...@@ -5,13 +5,14 @@ if TYPE_CHECKING:
from ..context import SfgContext from ..context import SfgContext
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pystencils import Field
from .basic_nodes import SfgCallTreeNode, SfgSequence, SfgBlock, SfgCustomStatement from .basic_nodes import SfgCallTreeNode, SfgSequence, SfgBlock, SfgCustomStatement
from .conditional import SfgCondition, SfgCustomCondition, SfgBranch from .conditional import SfgCondition, SfgCustomCondition, SfgBranch
from ..source_concepts.containers import SrcContiguousContainer
class SfgNodeBuilder(ABC): class SfgNodeBuilder(ABC):
def __init__(self, ctx: SfgContext) -> None:
self._ctx = ctx
@abstractmethod @abstractmethod
def resolve(self) -> SfgCallTreeNode: def resolve(self) -> SfgCallTreeNode:
pass pass
...@@ -40,8 +41,8 @@ class SfgSequencer: ...@@ -40,8 +41,8 @@ class SfgSequencer:
class SfgBranchBuilder(SfgNodeBuilder): class SfgBranchBuilder(SfgNodeBuilder):
def __init__(self, ctx: SfgContext) -> None: def __init__(self, ctx: SfgContext):
super().__init__(ctx) self._ctx = ctx
self._phase = 0 self._phase = 0
self._cond = None self._cond = None
...@@ -67,7 +68,7 @@ class SfgBranchBuilder(SfgNodeBuilder): ...@@ -67,7 +68,7 @@ class SfgBranchBuilder(SfgNodeBuilder):
self._branch_true = self._ctx.seq(*args) self._branch_true = self._ctx.seq(*args)
case 2: # Else-branch case 2: # Else-branch
self._branch_false = self._ctx.seq(*args) self._branch_false = self._ctx.seq(*args)
case _: # There's not third branch! case _: # There's no third branch!
raise TypeError("Branch construct already complete.") raise TypeError("Branch construct already complete.")
self._phase += 1 self._phase += 1
...@@ -77,4 +78,14 @@ class SfgBranchBuilder(SfgNodeBuilder): ...@@ -77,4 +78,14 @@ class SfgBranchBuilder(SfgNodeBuilder):
def resolve(self) -> SfgCallTreeNode: def resolve(self) -> SfgCallTreeNode:
return SfgBranch(self._cond, self._branch_true, self._branch_false) return SfgBranch(self._cond, self._branch_true, self._branch_false)
\ No newline at end of file class SfgFieldMappingBuilder(SfgNodeBuilder):
def __init__(self, ctx: SfgContext):
super().__init__(ctx)
self._field = None
self._container = None
def __call__(self, field: Field, container: SrcContiguousContainer):
self._field = field
self._container = container
\ No newline at end of file
...@@ -18,9 +18,6 @@ class SfgCustomCondition(SfgCondition): ...@@ -18,9 +18,6 @@ class SfgCustomCondition(SfgCondition):
def required_symbols(self) -> set(TypedSymbol): def required_symbols(self) -> set(TypedSymbol):
return set() return set()
def defined_symbols(self) -> set(TypedSymbol):
return set()
def get_code(self, ctx: SfgContext) -> str: def get_code(self, ctx: SfgContext) -> str:
return self._cond_text return self._cond_text
......
...@@ -3,10 +3,41 @@ from functools import reduce ...@@ -3,10 +3,41 @@ from functools import reduce
from pystencils.typing import TypedSymbol from pystencils.typing import TypedSymbol
from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgParameterDefinition
class FlattenSequences():
"""Flattens any nested sequences occuring in a kernel call tree."""
def visit(self, node: SfgCallTreeNode) -> None:
if isinstance(node, SfgSequence):
return self._visit_SfgSequence(node)
else:
for c in node.children:
self.visit(c)
def _visit_SfgSequence(self, sequence: SfgSequence) -> None:
children_flattened = []
def flatten(seq: SfgSequence):
for c in seq.children:
if isinstance(c, SfgSequence):
flatten(c)
else:
children_flattened.append(c)
flatten(sequence)
for c in children_flattened:
self.visit(c)
sequence._children = children_flattened
class ParameterCollector(): class ParameterCollector():
"""Collects all parameters required but not defined in a kernel call tree.
Requires that all sequences in the tree are flattened.
"""
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]: def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
if isinstance(node, SfgCallTreeLeaf): if isinstance(node, SfgCallTreeLeaf):
return self._visit_SfgCallTreeLeaf(node) return self._visit_SfgCallTreeLeaf(node)
...@@ -25,11 +56,10 @@ class ParameterCollector(): ...@@ -25,11 +56,10 @@ class ParameterCollector():
params = set() params = set()
for c in sequence.children[::-1]: for c in sequence.children[::-1]:
if isinstance(c, SfgCallTreeLeaf): if isinstance(c, SfgParameterDefinitionNode):
# Only a leaf in a sequence may effectively define symbols
# Remove these from the required parameters
params -= c.defined_symbols params -= c.defined_symbols
assert not isinstance(c, SfgSequence), "Sequence not flattened."
params |= self.visit(c) params |= self.visit(c)
return params return params
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment