Skip to content
Snippets Groups Projects
Commit 25d72d18 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

add pskernelfunction and parameter collector

parent be1a46b6
No related branches found
No related tags found
No related merge requests found
...@@ -22,5 +22,5 @@ __all__ = [ ...@@ -22,5 +22,5 @@ __all__ = [
"PsAssignment", "PsAssignment",
"PsDeclaration", "PsDeclaration",
"PsLoop", "PsLoop",
"ast_subs", "ast_subs"
] ]
from functools import reduce
from typing import Any, cast
from pymbolic.primitives import Variable
from pymbolic.mapper.dependency import DependencyMapper
from .kernelfunction import PsKernelFunction
from .nodes import PsAstNode, PsExpression, PsAssignment, PsDeclaration, PsLoop, PsBlock
from ..typed_expressions import PsTypedVariable
from ..exceptions import PsMalformedAstException, PsInternalCompilerError
class UndefinedVariablesCollector:
"""Collector for undefined variables.
This class implements an AST visitor that collects all `PsTypedVariable`s that have been used
in the AST without being defined prior to their usage.
"""
def __init__(self) -> None:
self._pb_dep_mapper = DependencyMapper(
include_subscripts=False,
include_lookups=False,
include_calls=False,
include_cses=False,
)
def collect(self, node: PsAstNode) -> set[PsTypedVariable]:
"""Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage."""
match node:
case PsKernelFunction(block):
return self.collect(block)
case PsExpression(expr):
variables: set[Variable] = self._pb_dep_mapper(expr)
for var in variables:
if not isinstance(var, PsTypedVariable):
raise PsMalformedAstException(
f"Non-typed variable {var} encountered"
)
return cast(set[PsTypedVariable], variables)
case PsAssignment(lhs, rhs):
return self.collect(lhs) | self.collect(rhs)
case PsBlock(statements):
undefined_vars = set()
for stmt in statements[::-1]:
undefined_vars -= self.declared_variables(stmt)
undefined_vars |= self.collect(stmt)
return undefined_vars
case PsLoop(ctr, start, stop, step, body):
undefined_vars = (
self.collect(start)
| self.collect(stop)
| self.collect(step)
| self.collect(body)
)
undefined_vars.remove(ctr.symbol)
return undefined_vars
case unknown:
raise PsInternalCompilerError(
f"Don't know how to collect undefined variables from {unknown}"
)
def declared_variables(self, node: PsAstNode) -> set[PsTypedVariable]:
"""Returns the set of variables declared by the given node which are visible in the enclosing scope."""
match node:
case PsDeclaration(lhs, _):
return {lhs.symbol}
case PsAssignment() | PsExpression() | PsLoop() | PsBlock():
return set()
case unknown:
raise PsInternalCompilerError(
f"Don't know how to collect declared variables from {unknown}"
)
from typing import Sequence
from typing import Generator
from .nodes import PsAstNode, PsBlock, failing_cast
from ..typed_expressions import PsTypedVariable
from ...enums import Target
class PsKernelFunction(PsAstNode):
"""A complete pystencils kernel function."""
__match_args__ = ("block",)
def __init__(self, body: PsBlock, target: Target, name: str = "kernel"):
self._body = body
self._target = target
self._name = name
@property
def target(self) -> Target:
"""See pystencils.Target"""
return self._target
@property
def body(self) -> PsBlock:
return self._body
@body.setter
def body(self, body: PsBlock):
self._body = body
@property
def name(self) -> str:
return self._name
@name.setter
def name(self, value: str):
self._name = value
def num_children(self) -> int:
return 1
def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._body, )
def get_child(self, idx: int):
if idx not in (0, -1):
raise IndexError(f"Child index out of bounds: {idx}")
return self._body
def set_child(self, idx: int, c: PsAstNode):
if idx not in (0, -1):
raise IndexError(f"Child index out of bounds: {idx}")
self._body = failing_cast(PsBlock, c)
def get_parameters(self) -> Sequence[PsTypedVariable]:
"""Collect the list of parameters to this function.
This function performs a full traversal of the AST.
To improve performance, make sure to cache the result if necessary.
"""
from .analysis import UndefinedVariablesCollector
params = UndefinedVariablesCollector().collect(self)
return sorted(params, key=lambda p: p.name)
\ No newline at end of file
from __future__ import annotations from __future__ import annotations
from typing import Sequence, Generator, TypeVar, Iterable, cast from typing import Sequence, Generator, Iterable, cast
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import pymbolic.primitives as pb import pymbolic.primitives as pb
from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue
from .util import failing_cast
T = TypeVar("T")
def failing_cast(target: type, obj: T):
if not isinstance(obj, target):
raise TypeError(f"Casting {obj} to {target} failed.")
return obj
class PsAstNode(ABC): class PsAstNode(ABC):
...@@ -49,20 +41,31 @@ class PsAstNode(ABC): ...@@ -49,20 +41,31 @@ class PsAstNode(ABC):
class PsBlock(PsAstNode): class PsBlock(PsAstNode):
__match_args__ = ("statements",)
def __init__(self, cs: Sequence[PsAstNode]): def __init__(self, cs: Sequence[PsAstNode]):
self._children = list(cs) self._statements = list(cs)
def num_children(self) -> int: def num_children(self) -> int:
return len(self._children) return len(self._statements)
def children(self) -> Generator[PsAstNode, None, None]: def children(self) -> Generator[PsAstNode, None, None]:
yield from self._children yield from self._statements
def get_child(self, idx: int): def get_child(self, idx: int):
return self._children[idx] return self._statements[idx]
def set_child(self, idx: int, c: PsAstNode): def set_child(self, idx: int, c: PsAstNode):
self._children[idx] = c self._statements[idx] = c
@property
def statements(self) -> list[PsAstNode]:
return self._statements
@statements.setter
def statemetns(self, stm: Sequence[PsAstNode]):
self._statements = list(stm)
class PsLeafNode(PsAstNode): class PsLeafNode(PsAstNode):
...@@ -82,6 +85,8 @@ class PsLeafNode(PsAstNode): ...@@ -82,6 +85,8 @@ class PsLeafNode(PsAstNode):
class PsExpression(PsLeafNode): class PsExpression(PsLeafNode):
"""Wrapper around pymbolics expressions.""" """Wrapper around pymbolics expressions."""
__match_args__ = ("expression",)
def __init__(self, expr: pb.Expression): def __init__(self, expr: pb.Expression):
self._expr = expr self._expr = expr
...@@ -107,6 +112,8 @@ class PsLvalueExpr(PsExpression): ...@@ -107,6 +112,8 @@ class PsLvalueExpr(PsExpression):
class PsSymbolExpr(PsLvalueExpr): class PsSymbolExpr(PsLvalueExpr):
"""Wrapper around PsTypedSymbols""" """Wrapper around PsTypedSymbols"""
__match_args__ = ("symbol",)
def __init__(self, symbol: PsTypedVariable): def __init__(self, symbol: PsTypedVariable):
super().__init__(symbol) super().__init__(symbol)
...@@ -120,6 +127,9 @@ class PsSymbolExpr(PsLvalueExpr): ...@@ -120,6 +127,9 @@ class PsSymbolExpr(PsLvalueExpr):
class PsAssignment(PsAstNode): class PsAssignment(PsAstNode):
__match_args__ = ("lhs", "rhs",)
def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression): def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression):
self._lhs = lhs self._lhs = lhs
self._rhs = rhs self._rhs = rhs
...@@ -160,6 +170,9 @@ class PsAssignment(PsAstNode): ...@@ -160,6 +170,9 @@ class PsAssignment(PsAstNode):
class PsDeclaration(PsAssignment): class PsDeclaration(PsAssignment):
__match_args__ = ("declared_variable", "rhs",)
def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression): def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression):
super().__init__(lhs, rhs) super().__init__(lhs, rhs)
...@@ -172,11 +185,11 @@ class PsDeclaration(PsAssignment): ...@@ -172,11 +185,11 @@ class PsDeclaration(PsAssignment):
self._lhs = failing_cast(PsSymbolExpr, lvalue) self._lhs = failing_cast(PsSymbolExpr, lvalue)
@property @property
def declared_symbol(self) -> PsSymbolExpr: def declared_variable(self) -> PsSymbolExpr:
return cast(PsSymbolExpr, self._lhs) return cast(PsSymbolExpr, self._lhs)
@declared_symbol.setter @declared_variable.setter
def declared_symbol(self, lvalue: PsSymbolExpr): def declared_variable(self, lvalue: PsSymbolExpr):
self._lhs = lvalue self._lhs = lvalue
def set_child(self, idx: int, c: PsAstNode): def set_child(self, idx: int, c: PsAstNode):
...@@ -190,6 +203,9 @@ class PsDeclaration(PsAssignment): ...@@ -190,6 +203,9 @@ class PsDeclaration(PsAssignment):
class PsLoop(PsAstNode): class PsLoop(PsAstNode):
__match_args__ = ("counter", "start", "stop", "step", "body")
def __init__( def __init__(
self, self,
ctr: PsSymbolExpr, ctr: PsSymbolExpr,
......
from typing import TypeVar
T = TypeVar("T")
def failing_cast(target: type, obj: T):
if not isinstance(obj, target):
raise TypeError(f"Casting {obj} to {target} failed.")
return obj
...@@ -3,6 +3,7 @@ from __future__ import annotations ...@@ -3,6 +3,7 @@ from __future__ import annotations
from pymbolic.mapper.c_code import CCodeMapper from pymbolic.mapper.c_code import CCodeMapper
from .ast import ast_visitor, PsAstNode, PsBlock, PsExpression, PsDeclaration, PsAssignment, PsLoop from .ast import ast_visitor, PsAstNode, PsBlock, PsExpression, PsDeclaration, PsAssignment, PsLoop
from .ast.kernelfunction import PsKernelFunction
class CPrinter: class CPrinter:
...@@ -23,6 +24,14 @@ class CPrinter: ...@@ -23,6 +24,14 @@ class CPrinter:
@ast_visitor @ast_visitor
def visit(self, node: PsAstNode): def visit(self, node: PsAstNode):
raise ValueError("Cannot print this node.") raise ValueError("Cannot print this node.")
@visit.case(PsKernelFunction)
def function(self, func: PsKernelFunction) -> str:
params = func.get_parameters()
params_str = ", ".join(f"{p.dtype} {p.name}" for p in params)
decl = f"FUNC_PREFIX void {func.name} ( {params_str} )"
body = self.visit(func.body)
return f"{decl}\n{body}"
@visit.case(PsBlock) @visit.case(PsBlock)
def block(self, block: PsBlock): def block(self, block: PsBlock):
...@@ -30,7 +39,7 @@ class CPrinter: ...@@ -30,7 +39,7 @@ class CPrinter:
return self.indent("{ }") return self.indent("{ }")
self._current_indent_level += self._indent_width self._current_indent_level += self._indent_width
interior = "".join(self.visit(c) for c in block.children()) interior = "\n".join(self.visit(c) for c in block.children())
self._current_indent_level -= self._indent_width self._current_indent_level -= self._indent_width
return self.indent("{\n") + interior + self.indent("}\n") return self.indent("{\n") + interior + self.indent("}\n")
...@@ -40,11 +49,11 @@ class CPrinter: ...@@ -40,11 +49,11 @@ class CPrinter:
@visit.case(PsDeclaration) @visit.case(PsDeclaration)
def declaration(self, decl: PsDeclaration): def declaration(self, decl: PsDeclaration):
lhs_symb = decl.declared_symbol.symbol lhs_symb = decl.declared_variable.symbol
lhs_dtype = lhs_symb.dtype lhs_dtype = lhs_symb.dtype
rhs_code = self.visit(decl.rhs) rhs_code = self.visit(decl.rhs)
return self.indent(f"{lhs_dtype} {lhs_symb.name} = {rhs_code};\n") return self.indent(f"{lhs_dtype} {lhs_symb.name} = {rhs_code};")
@visit.case(PsAssignment) @visit.case(PsAssignment)
def assignment(self, asm: PsAssignment): def assignment(self, asm: PsAssignment):
...@@ -54,7 +63,7 @@ class CPrinter: ...@@ -54,7 +63,7 @@ class CPrinter:
@visit.case(PsLoop) @visit.case(PsLoop)
def loop(self, loop: PsLoop): def loop(self, loop: PsLoop):
ctr_symbol = loop.counter.expression ctr_symbol = loop.counter.symbol
ctr = ctr_symbol.name ctr = ctr_symbol.name
start_code = self.visit(loop.start) start_code = self.visit(loop.start)
stop_code = self.visit(loop.stop) stop_code = self.visit(loop.stop)
......
class PsInternalCompilerError(Exception):
pass
class PsMalformedAstException(Exception):
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment