Skip to content
Snippets Groups Projects
Commit 43daedc4 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

some minor refactoring

parent c509ac39
No related branches found
No related tags found
No related merge requests found
# Build System Integration
## Configurator Script
The configurator script should configure the code generator and provide global configuration to all codegen scripts.
In the CMake integration, it can be specified globally via the `PystencilsSfg_CONFIGURATOR_SCRIPT` cache variable.
To decide and implement:
- Use `runpy` and communicate via a global variable, or use `importlib.util.spec_from_file_location` and communicate via
a function call? In either case, there needs to be concensus about at least one name in the configurator script.
- Allow specifying a separate configurator file at `pystencilssfg_generate_target_sources`? Sound sensible... It's basically
for free with the potential to add lots of flexibility
## Generator flags
Two separate lists of flags may be passed to generator scripts: Some may be evaluated by the SFG, and the rest
will be passed on to the user script.
Arguments to the SFG include:
- Path of the configurator script
- Output directory
How to separate user from generator arguments?
# pystencils Source File Generator (ps-sfg)
# pystencils Source File Generator (pystencils-sfg)
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable
if TYPE_CHECKING:
from ..context import SfgContext
from ..source_components import SfgHeaderInclude
from typing import TYPE_CHECKING, Sequence, Set
from abc import ABC, abstractmethod
from itertools import chain
from ..kernel_namespace import SfgKernelHandle
from ..source_concepts.source_objects import SrcObject, TypedSymbolOrObject
from ..exceptions import SfgException
from pystencils.typing import TypedSymbol
if TYPE_CHECKING:
from ..context import SfgContext
from ..source_components import SfgHeaderInclude
class SfgCallTreeNode(ABC):
"""Base class for all nodes comprising SFG call trees. """
......@@ -72,7 +69,7 @@ class SfgStatements(SfgCallTreeLeaf):
required_objects: Objects (as `SrcObject` or `TypedSymbol`) that are required as input to these statements.
"""
def __init__(self,
def __init__(self,
code_string: str,
defined_params: Sequence[TypedSymbolOrObject],
required_params: Sequence[TypedSymbolOrObject]):
......
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence, Set, Union, Iterable
if TYPE_CHECKING:
from ..context import SfgContext
from typing import TYPE_CHECKING, Set
from functools import reduce
......@@ -12,8 +10,13 @@ from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgState
from .deferred_nodes import SfgParamCollectionDeferredNode
if TYPE_CHECKING:
from ..context import SfgContext
class FlattenSequences():
"""Flattens any nested sequences occuring in a kernel call tree."""
def visit(self, node: SfgCallTreeNode) -> None:
if isinstance(node, SfgSequence):
return self._visit_SfgSequence(node)
......@@ -23,14 +26,14 @@ class FlattenSequences():
def _visit_SfgSequence(self, sequence: SfgSequence) -> None:
children_flattened = []
def flatten(seq: SfgSequence):
for c in seq.children:
if isinstance(c, SfgSequence):
flatten(c)
else:
children_flattened.append(c)
flatten(sequence)
for c in children_flattened:
......@@ -49,14 +52,15 @@ class CollectIncludes:
class ExpandingParameterCollector():
"""Collects all parameters required but not defined in a kernel call tree.
Expands any deferred nodes of type `SfgParamCollectionDeferredNode` found within sequences on the way.
"""
def __init__(self, ctx: SfgContext) -> None:
self._ctx = ctx
self._flattener = FlattenSequences()
"""Collects all parameters required but not defined in a kernel call tree.
Expands any deferred nodes of type `SfgParamCollectionDeferredNode` found within sequences on the way.
"""
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
if isinstance(node, SfgCallTreeLeaf):
return self._visit_SfgCallTreeLeaf(node)
......@@ -72,13 +76,13 @@ class ExpandingParameterCollector():
"""
Only in a sequence may parameters be defined and visible to subsequent nodes.
"""
params = set()
def iter_nested_sequences(seq: SfgSequence, visible_params: Set[TypedSymbol]):
for i in range(len(seq.children) - 1, -1, -1):
c = seq.children[i]
if isinstance(c, SfgParamCollectionDeferredNode):
c = c.expand(self._ctx, visible_params=visible_params)
seq.replace_child(i, c)
......@@ -88,7 +92,7 @@ class ExpandingParameterCollector():
else:
if isinstance(c, SfgStatements):
visible_params -= c.defined_parameters
visible_params |= self.visit(c)
iter_nested_sequences(sequence, params)
......@@ -108,6 +112,7 @@ class ParameterCollector():
Requires that all sequences in the tree are flattened.
"""
def visit(self, node: SfgCallTreeNode) -> Set[TypedSymbol]:
if isinstance(node, SfgCallTreeLeaf):
return self._visit_SfgCallTreeLeaf(node)
......@@ -123,12 +128,12 @@ class ParameterCollector():
"""
Only in a sequence may parameters be defined and visible to subsequent nodes.
"""
params = set()
for c in sequence.children[::-1]:
if isinstance(c, SfgStatements):
params -= c.defined_parameters
assert not isinstance(c, SfgSequence), "Sequence not flattened."
params |= self.visit(c)
return params
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment