Skip to content
Snippets Groups Projects
Commit d2520fd3 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

refactored branching structure, satisfied mypy.

parent 7b552412
No related branches found
No related tags found
No related merge requests found
Pipeline #59671 failed
from .nodes import ( from .nodes import (
PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr, PsAstNode,
PsAssignment, PsDeclaration, PsLoop PsBlock,
PsExpression,
PsLvalueExpr,
PsSymbolExpr,
PsAssignment,
PsDeclaration,
PsLoop,
) )
from .dispatcher import ast_visitor from .dispatcher import ast_visitor
from .transformations import ast_subs from .transformations import ast_subs
__all__ = [ __all__ = [
ast_visitor, "ast_visitor",
PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr, PsAssignment, PsDeclaration, PsLoop, "PsAstNode",
ast_subs "PsBlock",
"PsExpression",
"PsLvalueExpr",
"PsSymbolExpr",
"PsAssignment",
"PsDeclaration",
"PsLoop",
"ast_subs",
] ]
from __future__ import annotations from __future__ import annotations
from typing import Sequence, Generator from typing import Sequence, Generator, TypeVar, Iterable, cast
from abc import ABC from abc import ABC, abstractmethod
import pymbolic.primitives as pb import pymbolic.primitives as pb
from ..typed_expressions import PsTypedVariable, PsLvalue from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue
class PsAstNode(ABC): T = TypeVar("T")
"""Base class for all nodes in the pystencils AST.""" def failing_cast(target: type, obj: T):
if not isinstance(obj, target):
raise TypeError(f"Casting {obj} to {target} failed.")
return obj
def __init__(self, *children: Sequence[PsAstNode]):
for c in children:
if not isinstance(c, PsAstNode):
raise TypeError(f"Child {c} was not a PsAstNode.")
self._children = list(children)
@property class PsAstNode(ABC):
"""Base class for all nodes in the pystencils AST.
This base class provides a common interface to inspect and update the AST's branching structure.
The four methods `num_children`, `children`, `get_child` and `set_child` must be implemented by
each subclass.
Subclasses are also responsible for doing the necessary type checks if they place restrictions on
the types of their children.
"""
@abstractmethod
def num_children(self) -> int:
...
@abstractmethod
def children(self) -> Generator[PsAstNode, None, None]: def children(self) -> Generator[PsAstNode, None, None]:
yield from self._children ...
def child(self, idx: int): @abstractmethod
return self._children[idx] def get_child(self, idx: int):
...
@children.setter @abstractmethod
def children(self, cs: Sequence[PsAstNode]): def set_child(self, idx: int, c: PsAstNode):
if len(cs) != len(self._children): ...
raise ValueError("The number of child nodes must remain the same!")
def set_children(self, cs: Iterable[PsAstNode]):
for i, c in enumerate(cs):
self.set_child(i, c)
class PsBlock(PsAstNode):
def __init__(self, cs: Sequence[PsAstNode]):
self._children = list(cs) self._children = list(cs)
def __getitem__(self, idx: int): def num_children(self) -> int:
return len(self._children)
def children(self) -> Generator[PsAstNode, None, None]:
yield from self._children
def get_child(self, idx: int):
return self._children[idx] return self._children[idx]
def __setitem__(self, idx: int, c: PsAstNode): def set_child(self, idx: int, c: PsAstNode):
self._children[idx] = c self._children[idx] = c
class PsLeafNode(PsAstNode):
def num_children(self) -> int:
return 0
class PsBlock(PsAstNode):
@property
def children(self) -> Generator[PsAstNode, None, None]: def children(self) -> Generator[PsAstNode, None, None]:
yield from self._children # need to override entire property to override the setter yield from ()
@children.setter def get_child(self, idx: int):
def children(self, cs: Sequence[PsAstNode]): raise IndexError("Child index out of bounds: Leaf nodes have no children.")
self._children = cs
def set_child(self, idx: int, c: PsAstNode):
raise IndexError("Child index out of bounds: Leaf nodes have no children.")
class PsExpression(PsAstNode):
class PsExpression(PsLeafNode):
"""Wrapper around pymbolics expressions.""" """Wrapper around pymbolics expressions."""
def __init__(self, expr: pb.Expression): def __init__(self, expr: pb.Expression):
super().__init__()
self._expr = expr self._expr = expr
@property @property
...@@ -68,7 +95,7 @@ class PsLvalueExpr(PsExpression): ...@@ -68,7 +95,7 @@ class PsLvalueExpr(PsExpression):
"""Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment""" """Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment"""
def __init__(self, expr: PsLvalue): def __init__(self, expr: PsLvalue):
if not isinstance(expr, PsLvalue): if not isinstance(expr, (PsTypedVariable, PsArrayAccess)):
raise TypeError("Expression was not a valid lvalue") raise TypeError("Expression was not a valid lvalue")
super(PsLvalueExpr, self).__init__(expr) super(PsLvalueExpr, self).__init__(expr)
...@@ -78,52 +105,85 @@ class PsSymbolExpr(PsLvalueExpr): ...@@ -78,52 +105,85 @@ class PsSymbolExpr(PsLvalueExpr):
"""Wrapper around PsTypedSymbols""" """Wrapper around PsTypedSymbols"""
def __init__(self, symbol: PsTypedVariable): def __init__(self, symbol: PsTypedVariable):
if not isinstance(symbol, PsTypedVariable): super().__init__(symbol)
raise TypeError("Not a symbol!")
super(PsLvalueExpr, self).__init__(symbol)
@property @property
def symbol(self) -> PsSymbolExpr: def symbol(self) -> PsTypedVariable:
return self.expression return cast(PsTypedVariable, self._expr)
@symbol.setter @symbol.setter
def symbol(self, symbol: PsSymbolExpr): def symbol(self, symbol: PsTypedVariable):
self.expression = symbol self._expr = symbol
class PsAssignment(PsAstNode): class PsAssignment(PsAstNode):
def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression): def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression):
super(PsAssignment, self).__init__(lhs, rhs) self._lhs = lhs
self._rhs = rhs
@property @property
def lhs(self) -> PsLvalueExpr: def lhs(self) -> PsLvalueExpr:
return self._children[0] return self._lhs
@lhs.setter @lhs.setter
def lhs(self, lvalue: PsLvalueExpr): def lhs(self, lvalue: PsLvalueExpr):
self._children[0] = lvalue self._lhs = lvalue
@property @property
def rhs(self) -> PsExpression: def rhs(self) -> PsExpression:
return self._children[1] return self._rhs
@rhs.setter @rhs.setter
def rhs(self, expr: PsExpression): def rhs(self, expr: PsExpression):
self._children[1] = expr self._rhs = expr
def num_children(self) -> int:
return 2
def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._lhs, self._rhs)
def get_child(self, idx: int):
return (self._lhs, self._rhs)[idx]
def set_child(self, idx: int, c: PsAstNode):
idx = [0, 1][idx] # trick to normalize index
if idx == 0:
self._lhs = failing_cast(PsLvalueExpr, c)
elif idx == 1:
self._rhs = failing_cast(PsExpression, c)
else:
assert False, "unreachable code"
class PsDeclaration(PsAssignment): class PsDeclaration(PsAssignment):
def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression): def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression):
super(PsDeclaration, self).__init__(lhs, rhs) super().__init__(lhs, rhs)
@property @property
def lhs(self) -> PsSymbolExpr: def lhs(self) -> PsLvalueExpr:
return self._children[0] return self._lhs
@lhs.setter @lhs.setter
def lhs(self, symbol_node: PsSymbolExpr): def lhs(self, lvalue: PsLvalueExpr):
self._children[0] = symbol_node self._lhs = failing_cast(PsSymbolExpr, lvalue)
@property
def declared_symbol(self) -> PsSymbolExpr:
return cast(PsSymbolExpr, self._lhs)
@declared_symbol.setter
def declared_symbol(self, lvalue: PsSymbolExpr):
self._lhs = lvalue
def set_child(self, idx: int, c: PsAstNode):
idx = [0, 1][idx] # trick to normalize index
if idx == 0:
self._lhs = failing_cast(PsSymbolExpr, c)
elif idx == 1:
self._rhs = failing_cast(PsExpression, c)
else:
assert False, "unreachable code"
class PsLoop(PsAstNode): class PsLoop(PsAstNode):
...@@ -133,40 +193,68 @@ class PsLoop(PsAstNode): ...@@ -133,40 +193,68 @@ class PsLoop(PsAstNode):
stop: PsExpression, stop: PsExpression,
step: PsExpression, step: PsExpression,
body: PsBlock): body: PsBlock):
super(PsLoop, self).__init__(ctr, start, stop, step, body) self._ctr = ctr
self._start = start
self._stop = stop
self._step = step
self._body = body
@property @property
def counter(self) -> PsSymbolExpr: def counter(self) -> PsSymbolExpr:
return self._children[0] return self._ctr
@counter.setter
def counter(self, expr: PsSymbolExpr):
self._ctr = expr
@property @property
def start(self) -> PsExpression: def start(self) -> PsExpression:
return self._children[1] return self._start
@start.setter @start.setter
def start(self, expr: PsExpression): def start(self, expr: PsExpression):
self._children[1] = expr self._start = expr
@property @property
def stop(self) -> PsExpression: def stop(self) -> PsExpression:
return self._children[2] return self._stop
@stop.setter @stop.setter
def stop(self, expr: PsExpression): def stop(self, expr: PsExpression):
self._children[2] = expr self._stop = expr
@property @property
def step(self) -> PsExpression: def step(self) -> PsExpression:
return self._children[3] return self._step
@step.setter @step.setter
def step(self, expr: PsExpression): def step(self, expr: PsExpression):
self._children[3] = expr self._step = expr
@property @property
def body(self) -> PsBlock: def body(self) -> PsBlock:
return self._children[4] return self._body
@body.setter @body.setter
def body(self, block: PsBlock): def body(self, block: PsBlock):
self._children[4] = block self._body = block
def num_children(self) -> int:
return 5
def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._ctr, self._start, self._stop, self._step, self._body)
def get_child(self, idx: int):
return (self._ctr, self._start, self._stop, self._step, self._body)[idx]
def set_child(self, idx: int, c: PsAstNode):
idx = list(range(5))[idx]
match idx:
case 0: self._ctr = failing_cast(PsSymbolExpr, c)
case 1: self._start = failing_cast(PsExpression, c)
case 2: self._stop = failing_cast(PsExpression, c)
case 3: self._step = failing_cast(PsExpression, c)
case 4: self._body = failing_cast(PsBlock, c)
case _: assert False, "unreachable code"
...@@ -12,7 +12,7 @@ from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression ...@@ -12,7 +12,7 @@ from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression
class PsAstTransformer(ABC): class PsAstTransformer(ABC):
def transform_children(self, node: PsAstNode, *args, **kwargs): def transform_children(self, node: PsAstNode, *args, **kwargs):
node.children = [self.visit(c, *args, **kwargs) for c in node.children] node.set_children(self.visit(c, *args, **kwargs) for c in node.children())
@ast_visitor @ast_visitor
def visit(self, node, *args, **kwargs): def visit(self, node, *args, **kwargs):
......
...@@ -26,11 +26,11 @@ class CPrinter: ...@@ -26,11 +26,11 @@ class CPrinter:
@visit.case(PsBlock) @visit.case(PsBlock)
def block(self, block: PsBlock): def block(self, block: PsBlock):
if not block.children: if not block.children():
return self.indent("{ }") return self.indent("{ }")
self._current_indent_level += self._indent_width self._current_indent_level += self._indent_width
interior = "".join(self.visit(c) for c in block.children) interior = "".join(self.visit(c) for c in block.children())
self._current_indent_level -= self._indent_width self._current_indent_level -= self._indent_width
return self.indent("{\n") + interior + self.indent("}\n") return self.indent("{\n") + interior + self.indent("}\n")
...@@ -40,7 +40,7 @@ class CPrinter: ...@@ -40,7 +40,7 @@ class CPrinter:
@visit.case(PsDeclaration) @visit.case(PsDeclaration)
def declaration(self, decl: PsDeclaration): def declaration(self, decl: PsDeclaration):
lhs_symb = decl.lhs.symbol lhs_symb = decl.declared_symbol.symbol
lhs_dtype = lhs_symb.dtype lhs_dtype = lhs_symb.dtype
rhs_code = self.visit(decl.rhs) rhs_code = self.visit(decl.rhs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment