Skip to content
Snippets Groups Projects

Introduction of structural ast nodes

Merged Richard Angersbach requested to merge rangersbach/structural into v2.0-dev
Files
8
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Iterable, Sequence, cast
from types import NoneType
@@ -9,10 +11,35 @@ from ..memory import PsSymbol
from .util import failing_cast
class PsBlock(PsAstNode):
class PsStructuralNode(PsAstNode, ABC):
"""Base class for structural nodes in the pystencils AST.
This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from.
"""
def clone(self):
"""Clone this structure node.
.. note::
Subclasses of `PsStructuralNode` should not override this method,
but implement `_clone_structural` instead.
That implementation shall call `clone` on any of its children.
"""
return self._clone_structural()
@abstractmethod
def _clone_structural(self) -> PsStructuralNode:
"""Implementation of structural node cloning.
:meta public:
"""
pass
class PsBlock(PsStructuralNode):
__match_args__ = ("statements",)
def __init__(self, cs: Iterable[PsAstNode]):
def __init__(self, cs: Iterable[PsStructuralNode]):
self._statements = list(cs)
@property
@@ -21,23 +48,23 @@ class PsBlock(PsAstNode):
@children.setter
def children(self, cs: Sequence[PsAstNode]):
self._statements = list(cs)
self._statements = list([failing_cast(PsStructuralNode, c) for c in cs])
def get_children(self) -> tuple[PsAstNode, ...]:
return tuple(self._statements)
def set_child(self, idx: int, c: PsAstNode):
self._statements[idx] = c
self._statements[idx] = failing_cast(PsStructuralNode, c)
def clone(self) -> PsBlock:
return PsBlock([stmt.clone() for stmt in self._statements])
def _clone_structural(self) -> PsBlock:
return PsBlock([stmt._clone_structural() for stmt in self._statements])
@property
def statements(self) -> list[PsAstNode]:
def statements(self) -> list[PsStructuralNode]:
return self._statements
@statements.setter
def statements(self, stm: Sequence[PsAstNode]):
def statements(self, stm: Sequence[PsStructuralNode]):
self._statements = list(stm)
def __repr__(self) -> str:
@@ -45,7 +72,7 @@ class PsBlock(PsAstNode):
return f"PsBlock( {contents} )"
class PsStatement(PsAstNode):
class PsStatement(PsStructuralNode):
__match_args__ = ("expression",)
def __init__(self, expr: PsExpression):
@@ -59,7 +86,7 @@ class PsStatement(PsAstNode):
def expression(self, expr: PsExpression):
self._expression = expr
def clone(self) -> PsStatement:
def _clone_structural(self) -> PsStatement:
return PsStatement(self._expression.clone())
def get_children(self) -> tuple[PsAstNode, ...]:
@@ -71,7 +98,7 @@ class PsStatement(PsAstNode):
self._expression = failing_cast(PsExpression, c)
class PsAssignment(PsAstNode):
class PsAssignment(PsStructuralNode):
__match_args__ = (
"lhs",
"rhs",
@@ -101,7 +128,7 @@ class PsAssignment(PsAstNode):
def rhs(self, expr: PsExpression):
self._rhs = expr
def clone(self) -> PsAssignment:
def _clone_structural(self) -> PsAssignment:
return PsAssignment(self._lhs.clone(), self._rhs.clone())
def get_children(self) -> tuple[PsAstNode, ...]:
@@ -141,7 +168,7 @@ class PsDeclaration(PsAssignment):
def declared_symbol(self) -> PsSymbol:
return cast(PsSymbolExpr, self._lhs).symbol
def clone(self) -> PsDeclaration:
def _clone_structural(self) -> PsDeclaration:
return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone())
def set_child(self, idx: int, c: PsAstNode):
@@ -157,7 +184,7 @@ class PsDeclaration(PsAssignment):
return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})"
class PsLoop(PsAstNode):
class PsLoop(PsStructuralNode):
__match_args__ = ("counter", "start", "stop", "step", "body")
def __init__(
@@ -214,13 +241,13 @@ class PsLoop(PsAstNode):
def body(self, block: PsBlock):
self._body = block
def clone(self) -> PsLoop:
def _clone_structural(self) -> PsLoop:
return PsLoop(
self._ctr.clone(),
self._start.clone(),
self._stop.clone(),
self._step.clone(),
self._body.clone(),
self._body._clone_structural(),
)
def get_children(self) -> tuple[PsAstNode, ...]:
@@ -243,7 +270,7 @@ class PsLoop(PsAstNode):
assert False, "unreachable code"
class PsConditional(PsAstNode):
class PsConditional(PsStructuralNode):
"""Conditional branch"""
__match_args__ = ("condition", "branch_true", "branch_false")
@@ -282,11 +309,11 @@ class PsConditional(PsAstNode):
def branch_false(self, block: PsBlock | None):
self._branch_false = block
def clone(self) -> PsConditional:
def _clone_structural(self) -> PsConditional:
return PsConditional(
self._condition.clone(),
self._branch_true.clone(),
self._branch_false.clone() if self._branch_false is not None else None,
self._branch_true._clone_structural(),
self._branch_false._clone_structural() if self._branch_false is not None else None,
)
def get_children(self) -> tuple[PsAstNode, ...]:
@@ -317,7 +344,7 @@ class PsEmptyLeafMixIn:
pass
class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
"""A C/C++ preprocessor pragma.
Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``.
@@ -335,7 +362,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
def text(self) -> str:
return self._text
def clone(self) -> PsPragma:
def _clone_structural(self) -> PsPragma:
return PsPragma(self.text)
def structurally_equal(self, other: PsAstNode) -> bool:
@@ -345,7 +372,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
return self._text == other._text
class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
__match_args__ = ("lines",)
def __init__(self, text: str) -> None:
@@ -360,7 +387,7 @@ class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
def lines(self) -> tuple[str, ...]:
return self._lines
def clone(self) -> PsComment:
def _clone_structural(self) -> PsComment:
return PsComment(self._text)
def structurally_equal(self, other: PsAstNode) -> bool:
Loading