Skip to content
Snippets Groups Projects
Commit e3757c08 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Added AST prototype for backend rework

parent 40d83d2f
Branches
Tags
No related merge requests found
from .nodes import (
PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr,
PsAssignment, PsDeclaration, PsLoop
)
from .dispatcher import ast_visitor
from .transformations import ast_subs
__all__ = [
ast_visitor,
PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr, PsAssignment, PsDeclaration, PsLoop,
ast_subs
]
from __future__ import annotations
from typing import Callable
from types import MethodType
from functools import wraps
from .nodes import PsAstNode
class VisitorDispatcher:
def __init__(self, wrapped_method):
self._dispatch_dict = {}
self._wrapped_method = wrapped_method
def case(self, node_type: type):
"""Decorator for visitor's methods"""
def decorate(handler: Callable):
if node_type in self._dispatch_dict:
raise ValueError(f"Duplicate visitor case {node_type}")
self._dispatch_dict[node_type] = handler
return handler
return decorate
def __call__(self, instance, node: PsAstNode, *args, **kwargs):
for cls in node.__class__.mro():
if cls in self._dispatch_dict:
return self._dispatch_dict[cls](instance, node, *args, **kwargs)
return self._wrapped_method(instance, node, *args, **kwargs)
def __get__(self, obj, objtype=None):
if obj is None:
return self
return MethodType(self, obj)
def ast_visitor(method):
return wraps(method)(VisitorDispatcher(method))
from __future__ import annotations
from typing import Sequence, Generator
from abc import ABC
import pymbolic.primitives as pb
from ..typed_expressions import PsTypedSymbol, PsLvalue
class PsAstNode(ABC):
"""Base class for all nodes in the pystencils AST."""
def __init__(self, *children: Sequence[PsAstNode]):
for c in children:
if not isinstance(c, PsAstNode):
raise TypeError(f"Child {c} was not a PsAstNode.")
self._children = list(children)
@property
def children(self) -> Generator[PsAstNode, None, None]:
yield from self._children
def child(self, idx: int):
return self._children[idx]
@children.setter
def children(self, cs: Sequence[PsAstNode]):
if len(cs) != len(self._children):
raise ValueError("The number of child nodes must remain the same!")
self._children = list(cs)
def __getitem__(self, idx: int):
return self._children[idx]
def __setitem__(self, idx: int, c: PsAstNode):
self._children[idx] = c
class PsBlock(PsAstNode):
@PsAstNode.children.setter
def children(self, cs):
self._children = cs
class PsExpression(PsAstNode):
"""Wrapper around pymbolics expressions."""
def __init__(self, expr: pb.Expression):
super().__init__()
self._expr = expr
@property
def expression(self) -> pb.Expression:
return self._expr
@expression.setter
def expression(self, expr: pb.Expression):
self._expr = expr
class PsLvalueExpr(PsExpression):
"""Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment"""
def __init__(self, expr: PsLvalue):
if not isinstance(expr, PsLvalue):
raise TypeError("Expression was not a valid lvalue")
super(PsLvalueExpr, self).__init__(expr)
class PsSymbolExpr(PsLvalueExpr):
"""Wrapper around PsTypedSymbols"""
def __init__(self, symbol: PsTypedSymbol):
if not isinstance(symbol, PsTypedSymbol):
raise TypeError("Not a symbol!")
super(PsLvalueExpr, self).__init__(symbol)
@property
def symbol(self) -> PsSymbolExpr:
return self.expression
@symbol.setter
def symbol(self, symbol: PsSymbolExpr):
self.expression = symbol
class PsAssignment(PsAstNode):
def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression):
super(PsAssignment, self).__init__(lhs, rhs)
@property
def lhs(self) -> PsLvalueExpr:
return self._children[0]
@lhs.setter
def lhs(self, lvalue: PsLvalueExpr):
self._children[0] = lvalue
@property
def rhs(self) -> PsExpression:
return self._children[1]
@rhs.setter
def rhs(self, expr: PsExpression):
self._children[1] = expr
class PsDeclaration(PsAssignment):
def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression):
super(PsDeclaration, self).__init__(lhs, rhs)
@property
def lhs(self) -> PsSymbolExpr:
return self._children[0]
@lhs.setter
def lhs(self, symbol_node: PsSymbolExpr):
self._children[0] = symbol_node
class PsLoop(PsAstNode):
def __init__(self,
ctr: PsSymbolExpr,
start: PsExpression,
stop: PsExpression,
step: PsExpression,
body: PsBlock):
super(PsLoop, self).__init__(ctr, start, stop, step, body)
@property
def counter(self) -> PsSymbolExpr:
return self._children[0]
@property
def start(self) -> PsExpression:
return self._children[1]
@start.setter
def start(self, expr: PsExpression):
self._children[1] = expr
@property
def stop(self) -> PsExpression:
return self._children[2]
@stop.setter
def stop(self, expr: PsExpression):
self._children[2] = expr
@property
def step(self) -> PsExpression:
return self._children[3]
@step.setter
def step(self, expr: PsExpression):
self._children[3] = expr
@property
def body(self) -> PsBlock:
return self._children[4]
@body.setter
def body(self, block: PsBlock):
self._children[4] = block
from abc import ABC
from typing import Dict
from pymbolic.primitives import Expression
from pymbolic.mapper.substitutor import CachedSubstitutionMapper, make_subst_func
from ..typed_expressions import PsTypedSymbol
from .dispatcher import ast_visitor
from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression
class PsAstTransformer(ABC):
def transform_children(self, node: PsAstNode, *args, **kwargs):
node.children = [self.visit(c, *args, **kwargs) for c in node.children]
@ast_visitor
def visit(self, node, *args, **kwargs):
self.transform_children(node, *args, **kwargs)
return node
class PsSymbolsSubstitutor(PsAstTransformer):
def __init__(self, subs_dict: Dict[PsTypedSymbol, Expression]):
self._subs_dict = subs_dict
self._mapper = CachedSubstitutionMapper(lambda s : self._subs_dict.get(s, None))
def subs(self, node: PsAstNode):
return self.visit(node)
visit = PsAstTransformer.visit
@visit.case(PsAssignment)
def assignment(self, asm: PsAssignment):
lhs_expr = asm.lhs.expression
if isinstance(lhs_expr, PsTypedSymbol) and lhs_expr in self._subs_dict:
raise ValueError(f"Cannot substitute symbol {lhs_expr} that occurs on a left-hand side of an assignment.")
self.transform_children(asm)
return asm
@visit.case(PsLoop)
def loop(self, loop: PsLoop):
if loop.counter.expression in self._subs_dict:
raise ValueError(f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter.")
self.transform_children(loop)
return loop
@visit.case(PsExpression)
def expression(self, expr_node: PsExpression):
self._mapper(expr_node.expression)
def ast_subs(node: PsAstNode, subs_dict: dict):
return PsSymbolsSubstitutor(subs_dict).subs(node)
from __future__ import annotations
from pymbolic.mapper.c_code import CCodeMapper
from .ast import ast_visitor, PsAstNode, PsBlock, PsExpression, PsDeclaration, PsAssignment, PsLoop
class CPrinter:
def __init__(self, indent_width=3):
self._indent_width = indent_width
self._current_indent_level = 0
self._inside_expression = False # controls parentheses in nested arithmetic expressions
self._pb_cmapper = CCodeMapper()
def indent(self, line):
return " " * self._current_indent_level + line
def print(self, node: PsAstNode):
return self.visit(node)
@ast_visitor
def visit(self, node: PsAstNode):
raise ValueError("Cannot print this node.")
@visit.case(PsBlock)
def block(self, block: PsBlock):
if not block.children:
return self.indent("{ }")
self._current_indent_level += self._indent_width
interior = "".join(self.visit(c) for c in block.children)
self._current_indent_level -= self._indent_width
return self.indent("{\n") + interior + self.indent("}\n")
@visit.case(PsExpression)
def pymb_expression(self, expr: PsExpression):
return self._pb_cmapper(expr.expression)
@visit.case(PsDeclaration)
def declaration(self, decl: PsDeclaration):
lhs_symb = decl.lhs.symbol
lhs_dtype = lhs_symb.dtype
rhs_code = self.visit(decl.rhs)
return self.indent(f"{lhs_dtype} {lhs_symb.name} = {rhs_code};\n")
@visit.case(PsAssignment)
def assignment(self, asm: PsAssignment):
lhs_code = self.visit(asm.lhs)
rhs_code = self.visit(asm.rhs)
return self.indent(f"{lhs_code} = {rhs_code};\n")
@visit.case(PsLoop)
def loop(self, loop: PsLoop):
ctr_symbol = loop.counter.expression
ctr = ctr_symbol.name
start_code = self.visit(loop.start)
stop_code = self.visit(loop.stop)
step_code = self.visit(loop.step)
body_code = self.visit(loop.body)
code = f"for({ctr_symbol.dtype} {ctr} = {start_code};" + \
f" {ctr} < {stop_code};" + \
f" {ctr} += {step_code})\n" + \
body_code
return code
from __future__ import annotations
from typing import TypeAlias, Union, Any
import pymbolic.primitives as pb
from ..typing import AbstractType, BasicType
class PsTypedSymbol(pb.Variable):
def __init__(self, name: str, dtype: AbstractType):
super(PsTypedSymbol, self).__init__(name)
self._dtype = dtype
@property
def dtype(self) -> AbstractType:
return self._dtype
class PsArrayBasePointer(PsTypedSymbol):
def __init__(self, name: str, base_type: AbstractType):
super(PsArrayBasePointer, self).__init__(name, base_type)
class PsArrayAccess(pb.Subscript):
def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression):
super(PsArrayAccess, self).__init__(base_ptr, index)
PsLvalue : TypeAlias = Union[PsTypedSymbol, PsArrayAccess]
class PsTypedConstant:
@staticmethod
def _cast(value, target_dtype: AbstractType):
if isinstance(value, PsTypedConstant):
if value._dtype != target_dtype:
raise ValueError(f"Incompatible types: {value._dtype} and {target_dtype}")
return value
# TODO check legality
return PsTypedConstant(value, target_dtype)
def __init__(self, value, dtype: AbstractType):
"""Represents typed constants occuring in the pystencils AST"""
if isinstance(dtype, BasicType):
dtype = BasicType(dtype, const = True)
self._value = dtype.numpy_dtype.type(value)
else:
raise ValueError(f"Cannot create constant of type {dtype}")
self._dtype = dtype
def __str__(self) -> str:
return str(self._value)
def __add__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value + other._value, self._dtype)
def __mul__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value * other._value, self._dtype)
def __sub__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value - other._value, self._dtype)
# TODO: Remaining operators
pb.VALID_CONSTANT_CLASSES += (PsTypedConstant,)
......@@ -90,7 +90,7 @@ setuptools.setup(name='pystencils',
author_email='cs10-codegen@fau.de',
url='https://i10git.cs.fau.de/pycodegen/pystencils/',
packages=['pystencils'] + ['pystencils.' + s for s in setuptools.find_packages('pystencils')],
install_requires=['sympy>=1.6,<=1.11.1', 'numpy>=1.8.0', 'appdirs', 'joblib'],
install_requires=['sympy>=1.6,<=1.11.1', 'numpy>=1.8.0', 'pymbolic>=2022.2', 'appdirs', 'joblib'],
package_data={'pystencils': ['include/*.h',
'backends/cuda_known_functions.txt',
'backends/opencl1.1_known_functions.txt',
......@@ -131,6 +131,6 @@ setuptools.setup(name='pystencils',
'ipython',
'randomgen>=1.18'],
python_requires=">=3.8",
python_requires=">=3.10",
cmdclass=get_cmdclass()
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment