Skip to content
Snippets Groups Projects
Commit 43daedc4 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

some minor refactoring

parent c509ac39
No related branches found
No related tags found
No related merge requests found
# Build System Integration
## Configurator Script
The configurator script should configure the code generator and provide global configuration to all codegen scripts.
In the CMake integration, it can be specified globally via the `PystencilsSfg_CONFIGURATOR_SCRIPT` cache variable.
To decide and implement:
- Use `runpy` and communicate via a global variable, or use `importlib.util.spec_from_file_location` and communicate via
a function call? In either case, there needs to be concensus about at least one name in the configurator script.
- Allow specifying a separate configurator file at `pystencilssfg_generate_target_sources`? Sound sensible... It's basically
for free with the potential to add lots of flexibility
## Generator flags
Two separate lists of flags may be passed to generator scripts: Some may be evaluated by the SFG, and the rest
will be passed on to the user script.
Arguments to the SFG include:
- Path of the configurator script
- Output directory
How to separate user from generator arguments?
# pystencils Source File Generator (ps-sfg) # pystencils Source File Generator (pystencils-sfg)
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable from typing import TYPE_CHECKING, Sequence, Set
if TYPE_CHECKING:
from ..context import SfgContext
from ..source_components import SfgHeaderInclude
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from itertools import chain from itertools import chain
from ..kernel_namespace import SfgKernelHandle from ..kernel_namespace import SfgKernelHandle
from ..source_concepts.source_objects import SrcObject, TypedSymbolOrObject from ..source_concepts.source_objects import SrcObject, TypedSymbolOrObject
from ..exceptions import SfgException from ..exceptions import SfgException
from pystencils.typing import TypedSymbol if TYPE_CHECKING:
from ..context import SfgContext
from ..source_components import SfgHeaderInclude
class SfgCallTreeNode(ABC): class SfgCallTreeNode(ABC):
"""Base class for all nodes comprising SFG call trees. """ """Base class for all nodes comprising SFG call trees. """
...@@ -72,7 +69,7 @@ class SfgStatements(SfgCallTreeLeaf): ...@@ -72,7 +69,7 @@ class SfgStatements(SfgCallTreeLeaf):
required_objects: Objects (as `SrcObject` or `TypedSymbol`) that are required as input to these statements. required_objects: Objects (as `SrcObject` or `TypedSymbol`) that are required as input to these statements.
""" """
def __init__(self, def __init__(self,
code_string: str, code_string: str,
defined_params: Sequence[TypedSymbolOrObject], defined_params: Sequence[TypedSymbolOrObject],
required_params: Sequence[TypedSymbolOrObject]): required_params: Sequence[TypedSymbolOrObject]):
......
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable
if TYPE_CHECKING: from typing import TYPE_CHECKING, Set
from ..context import SfgContext
from functools import reduce from functools import reduce
...@@ -12,8 +10,13 @@ from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgState ...@@ -12,8 +10,13 @@ from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgState
from .deferred_nodes import SfgParamCollectionDeferredNode from .deferred_nodes import SfgParamCollectionDeferredNode
if TYPE_CHECKING:
from ..context import SfgContext
class FlattenSequences(): class FlattenSequences():
"""Flattens any nested sequences occuring in a kernel call tree.""" """Flattens any nested sequences occuring in a kernel call tree."""
def visit(self, node: SfgCallTreeNode) -> None: def visit(self, node: SfgCallTreeNode) -> None:
if isinstance(node, SfgSequence): if isinstance(node, SfgSequence):
return self._visit_SfgSequence(node) return self._visit_SfgSequence(node)
...@@ -23,14 +26,14 @@ class FlattenSequences(): ...@@ -23,14 +26,14 @@ class FlattenSequences():
def _visit_SfgSequence(self, sequence: SfgSequence) -> None: def _visit_SfgSequence(self, sequence: SfgSequence) -> None:
children_flattened = [] children_flattened = []
def flatten(seq: SfgSequence): def flatten(seq: SfgSequence):
for c in seq.children: for c in seq.children:
if isinstance(c, SfgSequence): if isinstance(c, SfgSequence):
flatten(c) flatten(c)
else: else:
children_flattened.append(c) children_flattened.append(c)
flatten(sequence) flatten(sequence)
for c in children_flattened: for c in children_flattened:
...@@ -49,14 +52,15 @@ class CollectIncludes: ...@@ -49,14 +52,15 @@ class CollectIncludes:
class ExpandingParameterCollector(): class ExpandingParameterCollector():
"""Collects all parameters required but not defined in a kernel call tree.
Expands any deferred nodes of type `SfgParamCollectionDeferredNode` found within sequences on the way.
"""
def __init__(self, ctx: SfgContext) -> None: def __init__(self, ctx: SfgContext) -> None:
self._ctx = ctx self._ctx = ctx
self._flattener = FlattenSequences() self._flattener = FlattenSequences()
"""Collects all parameters required but not defined in a kernel call tree.
Expands any deferred nodes of type `SfgParamCollectionDeferredNode` found within sequences on the way.
"""
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]: def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
if isinstance(node, SfgCallTreeLeaf): if isinstance(node, SfgCallTreeLeaf):
return self._visit_SfgCallTreeLeaf(node) return self._visit_SfgCallTreeLeaf(node)
...@@ -72,13 +76,13 @@ class ExpandingParameterCollector(): ...@@ -72,13 +76,13 @@ class ExpandingParameterCollector():
""" """
Only in a sequence may parameters be defined and visible to subsequent nodes. Only in a sequence may parameters be defined and visible to subsequent nodes.
""" """
params = set() params = set()
def iter_nested_sequences(seq: SfgSequence, visible_params: Set[TypedSymbol]): def iter_nested_sequences(seq: SfgSequence, visible_params: Set[TypedSymbol]):
for i in range(len(seq.children) - 1, -1, -1): for i in range(len(seq.children) - 1, -1, -1):
c = seq.children[i] c = seq.children[i]
if isinstance(c, SfgParamCollectionDeferredNode): if isinstance(c, SfgParamCollectionDeferredNode):
c = c.expand(self._ctx, visible_params=visible_params) c = c.expand(self._ctx, visible_params=visible_params)
seq.replace_child(i, c) seq.replace_child(i, c)
...@@ -88,7 +92,7 @@ class ExpandingParameterCollector(): ...@@ -88,7 +92,7 @@ class ExpandingParameterCollector():
else: else:
if isinstance(c, SfgStatements): if isinstance(c, SfgStatements):
visible_params -= c.defined_parameters visible_params -= c.defined_parameters
visible_params |= self.visit(c) visible_params |= self.visit(c)
iter_nested_sequences(sequence, params) iter_nested_sequences(sequence, params)
...@@ -108,6 +112,7 @@ class ParameterCollector(): ...@@ -108,6 +112,7 @@ class ParameterCollector():
Requires that all sequences in the tree are flattened. Requires that all sequences in the tree are flattened.
""" """
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]: def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
if isinstance(node, SfgCallTreeLeaf): if isinstance(node, SfgCallTreeLeaf):
return self._visit_SfgCallTreeLeaf(node) return self._visit_SfgCallTreeLeaf(node)
...@@ -123,12 +128,12 @@ class ParameterCollector(): ...@@ -123,12 +128,12 @@ class ParameterCollector():
""" """
Only in a sequence may parameters be defined and visible to subsequent nodes. Only in a sequence may parameters be defined and visible to subsequent nodes.
""" """
params = set() params = set()
for c in sequence.children[::-1]: for c in sequence.children[::-1]:
if isinstance(c, SfgStatements): if isinstance(c, SfgStatements):
params -= c.defined_parameters params -= c.defined_parameters
assert not isinstance(c, SfgSequence), "Sequence not flattened." assert not isinstance(c, SfgSequence), "Sequence not flattened."
params |= self.visit(c) params |= self.visit(c)
return params return params
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment