Skip to content
Snippets Groups Projects
Select Git revision
  • 3ddf218bb7b0ff191875ef46a2da836cf82903b9
  • master default protected
  • rangersbach/c-interfacing
  • fhennig/devel
  • v0.1a4
  • v0.1a3
  • v0.1a2
  • v0.1a1
8 results

visitors.py

Blame
  • visitors.py 4.48 KiB
    from __future__ import annotations
    
    from typing import TYPE_CHECKING
    
    from functools import reduce
    
    from .basic_nodes import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements
    from .deferred_nodes import SfgParamCollectionDeferredNode
    from .dispatcher import visitor
    from ..source_concepts.source_objects import TypedSymbolOrObject
    
    if TYPE_CHECKING:
        from ..context import SfgContext
    
    
    class FlattenSequences:
        """Flattens any nested sequences occurring in a kernel call tree."""

        @visitor
        def visit(self, node: SfgCallTreeNode) -> None:
            """Default case: recurse into every child of ``node``."""
            for child in node.children:
                self.visit(child)

        @visit.case(SfgSequence)
        def sequence(self, sequence: SfgSequence) -> None:
            """Replace the children of ``sequence`` by its flattened, in-order non-sequence descendants.

            Directly nested `SfgSequence` children are dissolved recursively; every remaining child
            is then visited so that sequences further down the tree get flattened as well.
            """

            def non_sequence_children(seq: SfgSequence):
                # Yield children in their original order, descending into nested sequences.
                for child in seq.children:
                    if isinstance(child, SfgSequence):
                        yield from non_sequence_children(child)
                    else:
                        yield child

            flat = list(non_sequence_children(sequence))

            for child in flat:
                self.visit(child)

            # The flattened list is installed only after all remaining children were visited.
            sequence._children = flat
    
    
    class CollectIncludes:
        """Collects the includes required by a call tree node and all of its descendants."""

        def visit(self, node: SfgCallTreeNode):
            """Return the union of ``required_includes`` over ``node`` and its whole subtree.

            The result is always a newly created set, so callers may modify it freely
            without affecting any node of the tree.
            """
            # Copy before accumulating: an in-place ``|=`` on the object handed out by
            # ``required_includes`` would add the descendants' includes to the node's own
            # set whenever the node exposes its backing set directly.
            includes = set(node.required_includes)
            for c in node.children:
                includes |= self.visit(c)

            return includes
    
    
    class ExpandingParameterCollector:
        """Gathers all parameters that a kernel call tree requires but does not define itself.

        Any `SfgParamCollectionDeferredNode` encountered inside a sequence is expanded on the way
        and replaced by its expansion in the parent sequence.
        """

        def __init__(self, ctx: SfgContext) -> None:
            self._flattener = FlattenSequences()
            self._ctx = ctx

        @visitor
        def visit(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
            """Default case: treat ``node`` as a plain branching node."""
            return self.branching_node(node)

        @visit.case(SfgCallTreeLeaf)
        def leaf(self, leaf: SfgCallTreeLeaf) -> set[TypedSymbolOrObject]:
            """A leaf contributes exactly its own required parameters."""
            return leaf.required_parameters

        @visit.case(SfgSequence)
        def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]:
            """
                Only in a sequence may parameters be defined and visible to subsequent nodes.

                Children are walked from last to first. Parameters defined by an `SfgStatements`
                child are removed from the running set before that child's own requirements are
                added. Nested sequences are walked with the same running set.
            """

            collected: set[TypedSymbolOrObject] = set()

            def walk_backwards(seq: SfgSequence, visible: set[TypedSymbolOrObject]):
                for idx in reversed(range(len(seq.children))):
                    child = seq.children[idx]

                    if isinstance(child, SfgParamCollectionDeferredNode):
                        # Expand using the parameters seen so far and store the result in place.
                        child = child.expand(self._ctx, visible_params=visible)
                        seq[idx] = child

                    if isinstance(child, SfgSequence):
                        walk_backwards(child, visible)
                        continue

                    if isinstance(child, SfgStatements):
                        visible -= child.defined_parameters

                    # In-place update: ``visible`` is the very set returned to the caller.
                    visible |= self.visit(child)

            walk_backwards(sequence, collected)

            return collected

        def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
            """
                Each interior node that is not a sequence simply requires the union of all parameters
                required by its children.
            """
            required: set[TypedSymbolOrObject] = set()
            for child in node.children:
                required = required | self.visit(child)
            return required
    
    
    class ParameterCollector:
        """Gathers all parameters that a kernel call tree requires but does not define itself.

        Requires that all sequences in the tree are flattened.
        """

        @visitor
        def visit(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
            """Default case: treat ``node`` as a plain branching node."""
            return self.branching_node(node)

        @visit.case(SfgCallTreeLeaf)
        def leaf(self, leaf: SfgCallTreeLeaf) -> set[TypedSymbolOrObject]:
            """A leaf contributes exactly its own required parameters."""
            return leaf.required_parameters

        @visit.case(SfgSequence)
        def sequence(self, sequence: SfgSequence) -> set[TypedSymbolOrObject]:
            """
                Only in a sequence may parameters be defined and visible to subsequent nodes.

                Children are walked from last to first. Parameters defined by an `SfgStatements`
                child are removed from the running set before that child's own requirements are
                added. A nested sequence trips the assertion below.
            """

            required: set[TypedSymbolOrObject] = set()
            for child in reversed(sequence.children):
                if isinstance(child, SfgStatements):
                    required -= child.defined_parameters

                assert not isinstance(child, SfgSequence), "Sequence not flattened."
                required |= self.visit(child)
            return required

        def branching_node(self, node: SfgCallTreeNode) -> set[TypedSymbolOrObject]:
            """
                Each interior node that is not a sequence simply requires the union of all parameters
                required by its children.
            """
            union: set[TypedSymbolOrObject] = set()
            for child in node.children:
                union = union | self.visit(child)
            return union