Skip to content
Snippets Groups Projects
Commit ed9315af authored by Frederik Hennig's avatar Frederik Hennig
Browse files

simplify children members of AST. introduce conditional branch.

parent d3a4ce73
No related branches found
No related tags found
No related merge requests found
Pipeline #60569 failed
......@@ -7,6 +7,7 @@ from .nodes import (
PsAssignment,
PsDeclaration,
PsLoop,
PsConditional,
)
from .kernelfunction import PsKernelFunction
......@@ -24,5 +25,6 @@ __all__ = [
"PsAssignment",
"PsDeclaration",
"PsLoop",
"PsConditional",
"ast_subs"
]
......@@ -102,7 +102,7 @@ class RequiredHeadersCollector(Collector):
case PsExpression(expr):
return self.rec(expr)
case node:
return reduce(set.union, (self(c) for c in node.children()), set())
return reduce(set.union, (self(c) for c in node.children), set())
def map_typed_variable(self, var: PsTypedVariable) -> set[str]:
return var.dtype.required_headers
......
from __future__ import annotations
from typing import Generator
from dataclasses import dataclass
from pymbolic.mapper.dependency import DependencyMapper
......@@ -104,16 +103,8 @@ class PsKernelFunction(PsAstNode):
"""For backward compatibility"""
return None
def num_children(self) -> int:
return 1
def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._body,)
def get_child(self, idx: int):
if idx not in (0, -1):
raise IndexError(f"Child index out of bounds: {idx}")
return self._body
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._body,)
def set_child(self, idx: int, c: PsAstNode):
if idx not in (0, -1):
......
from __future__ import annotations
from typing import Sequence, Generator, Iterable, cast, TypeAlias
from typing import Sequence, Iterable, cast, TypeAlias
from types import NoneType
from abc import ABC, abstractmethod
......@@ -12,32 +13,28 @@ class PsAstNode(ABC):
"""Base class for all nodes in the pystencils AST.
This base class provides a common interface to inspect and update the AST's branching structure.
The four methods `num_children`, `children`, `get_child` and `set_child` must be implemented by
each subclass.
The two methods `get_children` and `set_child` must be implemented by each subclass.
Subclasses are also responsible for doing the necessary type checks if they place restrictions on
the types of their children.
"""
@abstractmethod
def num_children(self) -> int:
...
@property
def children(self) -> tuple[PsAstNode, ...]:
return self.get_children()
@abstractmethod
def children(self) -> Generator[PsAstNode, None, None]:
...
@children.setter
def children(self, cs: Iterable[PsAstNode]):
for i, c in enumerate(cs):
self.set_child(i, c)
@abstractmethod
def get_child(self, idx: int):
def get_children(self) -> tuple[PsAstNode, ...]:
...
@abstractmethod
def set_child(self, idx: int, c: PsAstNode):
...
def set_children(self, cs: Iterable[PsAstNode]):
for i, c in enumerate(cs):
self.set_child(i, c)
class PsBlock(PsAstNode):
__match_args__ = ("statements",)
......@@ -45,14 +42,8 @@ class PsBlock(PsAstNode):
def __init__(self, cs: Sequence[PsAstNode]):
self._statements = list(cs)
def num_children(self) -> int:
return len(self._statements)
def children(self) -> Generator[PsAstNode, None, None]:
yield from self._statements
def get_child(self, idx: int):
return self._statements[idx]
def get_children(self) -> tuple[PsAstNode, ...]:
return tuple(self._statements)
def set_child(self, idx: int, c: PsAstNode):
self._statements[idx] = c
......@@ -67,14 +58,8 @@ class PsBlock(PsAstNode):
class PsLeafNode(PsAstNode):
def num_children(self) -> int:
return 0
def children(self) -> Generator[PsAstNode, None, None]:
yield from ()
def get_child(self, idx: int):
raise IndexError("Child index out of bounds: Leaf nodes have no children.")
def get_children(self) -> tuple[PsAstNode, ...]:
return ()
def set_child(self, idx: int, c: PsAstNode):
raise IndexError("Child index out of bounds: Leaf nodes have no children.")
......@@ -154,14 +139,8 @@ class PsAssignment(PsAstNode):
def rhs(self, expr: PsExpression):
self._rhs = expr
def num_children(self) -> int:
return 2
def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._lhs, self._rhs)
def get_child(self, idx: int):
return (self._lhs, self._rhs)[idx]
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._lhs, self._rhs)
def set_child(self, idx: int, c: PsAstNode):
idx = [0, 1][idx] # trick to normalize index
......@@ -265,14 +244,8 @@ class PsLoop(PsAstNode):
def body(self, block: PsBlock):
self._body = block
def num_children(self) -> int:
return 5
def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._ctr, self._start, self._stop, self._step, self._body)
def get_child(self, idx: int):
return (self._ctr, self._start, self._stop, self._step, self._body)[idx]
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._ctr, self._start, self._stop, self._step, self._body)
def set_child(self, idx: int, c: PsAstNode):
idx = list(range(5))[idx]
......@@ -289,3 +262,60 @@ class PsLoop(PsAstNode):
self._body = failing_cast(PsBlock, c)
case _:
assert False, "unreachable code"
class PsConditional(PsAstNode):
"""Conditional branch"""
__match_args__ = ("condition", "branch_true", "branch_false")
def __init__(
self,
cond: PsExpression,
branch_true: PsBlock,
branch_false: PsBlock | None = None,
):
self._condition = cond
self._branch_true = branch_true
self._branch_false = branch_false
@property
def condition(self) -> PsExpression:
return self._condition
@condition.setter
def condition(self, expr: PsExpression):
self._condition = expr
@property
def branch_true(self) -> PsBlock:
return self._branch_true
@branch_true.setter
def branch_true(self, block: PsBlock):
self._branch_true = block
@property
def branch_false(self) -> PsBlock | None:
return self._branch_false
@branch_false.setter
def branch_false(self, block: PsBlock | None):
self._branch_false = block
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._condition, self._branch_true) + (
(self._branch_false,) if self._branch_false is not None else ()
)
def set_child(self, idx: int, c: PsAstNode):
idx = list(range(3))[idx]
match idx:
case 0:
self._condition = failing_cast(PsExpression, c)
case 1:
self._branch_true = failing_cast(PsBlock, c)
case 2:
self._branch_false = failing_cast((PsBlock, NoneType), c)
case _:
assert False, "unreachable code"
......@@ -12,7 +12,7 @@ from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression
class PsAstTransformer(ABC):
def transform_children(self, node: PsAstNode, *args, **kwargs):
node.set_children(self.visit(c, *args, **kwargs) for c in node.children())
node.children = tuple(self.visit(c, *args, **kwargs) for c in node.children)
@ast_visitor
def visit(self, node, *args, **kwargs):
......
from typing import TypeVar
from typing import Any
T = TypeVar("T")
def failing_cast(target: type, obj: T):
def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any:
if not isinstance(obj, target):
raise TypeError(f"Casting {obj} to {target} failed.")
return obj
......@@ -2,14 +2,23 @@ from __future__ import annotations
from pymbolic.mapper.c_code import CCodeMapper
from .ast import ast_visitor, PsAstNode, PsBlock, PsExpression, PsDeclaration, PsAssignment, PsLoop
from .ast import (
ast_visitor,
PsAstNode,
PsBlock,
PsExpression,
PsDeclaration,
PsAssignment,
PsLoop,
PsConditional,
)
from .ast.kernelfunction import PsKernelFunction
def emit_code(kernel: PsKernelFunction):
# TODO: Specialize for different targets
printer = CPrinter()
return printer.print(kernel)
return printer.print(kernel)
class CPrinter:
......@@ -17,7 +26,6 @@ class CPrinter:
self._indent_width = indent_width
self._current_indent_level = 0
self._inside_expression = False # controls parentheses in nested arithmetic expressions
self._pb_cmapper = CCodeMapper()
......@@ -30,7 +38,7 @@ class CPrinter:
@ast_visitor
def visit(self, _: PsAstNode) -> str:
raise ValueError("Cannot print this node.")
@visit.case(PsKernelFunction)
def function(self, func: PsKernelFunction) -> str:
params_spec = func.get_parameters()
......@@ -41,11 +49,11 @@ class CPrinter:
@visit.case(PsBlock)
def block(self, block: PsBlock):
if not block.children():
if not block.children:
return self.indent("{ }")
self._current_indent_level += self._indent_width
interior = "\n".join(self.visit(c) for c in block.children())
interior = "\n".join(self.visit(c) for c in block.children)
self._current_indent_level -= self._indent_width
return self.indent("{\n") + interior + self.indent("}\n")
......@@ -77,8 +85,23 @@ class CPrinter:
body_code = self.visit(loop.body)
code = f"for({ctr_symbol.dtype} {ctr} = {start_code};" + \
f" {ctr} < {stop_code};" + \
f" {ctr} += {step_code})\n" + \
body_code
return code
code = (
f"for({ctr_symbol.dtype} {ctr} = {start_code};"
+ f" {ctr} < {stop_code};"
+ f" {ctr} += {step_code})\n"
+ body_code
)
return self.indent(code)
@visit.case(PsConditional)
def conditional(self, node: PsConditional):
cond_code = self.visit(node.condition)
then_code = self.visit(node.branch_true)
code = f"if({cond_code})\n{then_code}"
if node.branch_false is not None:
else_code = self.visit(node.branch_false)
code += f"\nelse\n{else_code}"
return self.indent(code)
......@@ -22,7 +22,6 @@ from ..arrays import (
)
from ..types import (
PsAbstractType,
PsScalarType,
PsUnsignedIntegerType,
PsSignedIntegerType,
PsIeeeFloatType,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment