diff --git a/pystencils/nbackend/__init__.py b/pystencils/nbackend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pystencils/nbackend/ast/__init__.py b/pystencils/nbackend/ast/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d4c33b8709b7cc2cf50369b48bb88b2e5bedca --- /dev/null +++ b/pystencils/nbackend/ast/__init__.py @@ -0,0 +1,13 @@ +from .nodes import ( + PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr, + PsAssignment, PsDeclaration, PsLoop +) + +from .dispatcher import ast_visitor +from .transformations import ast_subs + +__all__ = [ + ast_visitor, + PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr, PsAssignment, PsDeclaration, PsLoop, + ast_subs +] diff --git a/pystencils/nbackend/ast/dispatcher.py b/pystencils/nbackend/ast/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..a27f41bcf51ef68ccb1ef9c11860f1845e5ae311 --- /dev/null +++ b/pystencils/nbackend/ast/dispatcher.py @@ -0,0 +1,40 @@ +from __future__ import annotations +from typing import Callable +from types import MethodType + +from functools import wraps + +from .nodes import PsAstNode + +class VisitorDispatcher: + def __init__(self, wrapped_method): + self._dispatch_dict = {} + self._wrapped_method = wrapped_method + + def case(self, node_type: type): + """Decorator for visitor's methods""" + + def decorate(handler: Callable): + if node_type in self._dispatch_dict: + raise ValueError(f"Duplicate visitor case {node_type}") + self._dispatch_dict[node_type] = handler + return handler + + return decorate + + def __call__(self, instance, node: PsAstNode, *args, **kwargs): + for cls in node.__class__.mro(): + if cls in self._dispatch_dict: + return self._dispatch_dict[cls](instance, node, *args, **kwargs) + + return self._wrapped_method(instance, node, *args, **kwargs) + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return MethodType(self, obj) + + +def ast_visitor(method): + return wraps(method)(VisitorDispatcher(method)) + diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..09774f90367d627a0c8e60b40215cf1d57368187 --- /dev/null +++ b/pystencils/nbackend/ast/nodes.py @@ -0,0 +1,167 @@ +from __future__ import annotations +from typing import Sequence, Generator + +from abc import ABC + +import pymbolic.primitives as pb + +from ..typed_expressions import PsTypedSymbol, PsLvalue + + +class PsAstNode(ABC): + """Base class for all nodes in the pystencils AST.""" + + def __init__(self, *children: Sequence[PsAstNode]): + for c in children: + if not isinstance(c, PsAstNode): + raise TypeError(f"Child {c} was not a PsAstNode.") + self._children = list(children) + + @property + def children(self) -> Generator[PsAstNode, None, None]: + yield from self._children + + def child(self, idx: int): + return self._children[idx] + + @children.setter + def children(self, cs: Sequence[PsAstNode]): + if len(cs) != len(self._children): + raise ValueError("The number of child nodes must remain the same!") + self._children = list(cs) + + def __getitem__(self, idx: int): + return self._children[idx] + + def __setitem__(self, idx: int, c: PsAstNode): + self._children[idx] = c + + +class PsBlock(PsAstNode): + @PsAstNode.children.setter + def children(self, cs): + self._children = cs + + +class PsExpression(PsAstNode): + """Wrapper around pymbolics expressions.""" + + def __init__(self, expr: pb.Expression): + super().__init__() + self._expr = expr + + @property + def expression(self) -> pb.Expression: + return self._expr + + @expression.setter + def expression(self, expr: pb.Expression): + self._expr = expr + + +class PsLvalueExpr(PsExpression): + """Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment""" + + def __init__(self, expr: PsLvalue): + if not isinstance(expr, PsLvalue): + raise TypeError("Expression was not a valid lvalue") + + super(PsLvalueExpr, self).__init__(expr) + + +class PsSymbolExpr(PsLvalueExpr): + """Wrapper around PsTypedSymbols""" + + def __init__(self, symbol: PsTypedSymbol): + if not isinstance(symbol, PsTypedSymbol): + raise TypeError("Not a symbol!") + + super(PsLvalueExpr, self).__init__(symbol) + + @property + def symbol(self) -> PsSymbolExpr: + return self.expression + + @symbol.setter + def symbol(self, symbol: PsSymbolExpr): + self.expression = symbol + + +class PsAssignment(PsAstNode): + def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression): + super(PsAssignment, self).__init__(lhs, rhs) + + @property + def lhs(self) -> PsLvalueExpr: + return self._children[0] + + @lhs.setter + def lhs(self, lvalue: PsLvalueExpr): + self._children[0] = lvalue + + @property + def rhs(self) -> PsExpression: + return self._children[1] + + @rhs.setter + def rhs(self, expr: PsExpression): + self._children[1] = expr + + +class PsDeclaration(PsAssignment): + def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression): + super(PsDeclaration, self).__init__(lhs, rhs) + + @property + def lhs(self) -> PsSymbolExpr: + return self._children[0] + + @lhs.setter + def lhs(self, symbol_node: PsSymbolExpr): + self._children[0] = symbol_node + + +class PsLoop(PsAstNode): + def __init__(self, + ctr: PsSymbolExpr, + start: PsExpression, + stop: PsExpression, + step: PsExpression, + body: PsBlock): + super(PsLoop, self).__init__(ctr, start, stop, step, body) + + @property + def counter(self) -> PsSymbolExpr: + return self._children[0] + + @property + def start(self) -> PsExpression: + return self._children[1] + + @start.setter + def start(self, expr: PsExpression): + self._children[1] = expr + + @property + def stop(self) -> PsExpression: + return self._children[2] + + @stop.setter + def stop(self, expr: PsExpression): + self._children[2] = expr + + @property + def step(self) -> PsExpression: + return self._children[3] + + @step.setter + def step(self, expr: PsExpression): + self._children[3] = expr + + @property + def body(self) -> PsBlock: + return self._children[4] + + @body.setter + def body(self, block: PsBlock): + self._children[4] = block diff --git a/pystencils/nbackend/ast/transformations.py b/pystencils/nbackend/ast/transformations.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf23c0c7d3326bc8eb2041b6117ee760b06e375 --- /dev/null +++ b/pystencils/nbackend/ast/transformations.py @@ -0,0 +1,54 @@ +from abc import ABC + +from typing import Dict + +from pymbolic.primitives import Expression +from pymbolic.mapper.substitutor import CachedSubstitutionMapper, make_subst_func + +from ..typed_expressions import PsTypedSymbol +from .dispatcher import ast_visitor +from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression + + +class PsAstTransformer(ABC): + def transform_children(self, node: PsAstNode, *args, **kwargs): + node.children = [self.visit(c, *args, **kwargs) for c in node.children] + + @ast_visitor + def visit(self, node, *args, **kwargs): + self.transform_children(node, *args, **kwargs) + return node + + +class PsSymbolsSubstitutor(PsAstTransformer): + def __init__(self, subs_dict: Dict[PsTypedSymbol, Expression]): + self._subs_dict = subs_dict + self._mapper = CachedSubstitutionMapper(lambda s : self._subs_dict.get(s, None)) + + def subs(self, node: PsAstNode): + return self.visit(node) + + visit = PsAstTransformer.visit + + @visit.case(PsAssignment) + def assignment(self, asm: PsAssignment): + lhs_expr = asm.lhs.expression + if isinstance(lhs_expr, PsTypedSymbol) and lhs_expr in self._subs_dict: + raise ValueError(f"Cannot substitute symbol {lhs_expr} that occurs on a left-hand side of an assignment.") + self.transform_children(asm) + return asm + + @visit.case(PsLoop) + def loop(self, loop: PsLoop): + if loop.counter.expression in self._subs_dict: + raise ValueError(f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter.") + self.transform_children(loop) + return loop + + @visit.case(PsExpression) + def expression(self, expr_node: PsExpression): + self._mapper(expr_node.expression) + + +def ast_subs(node: PsAstNode, subs_dict: dict): + return PsSymbolsSubstitutor(subs_dict).subs(node) diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/c_printer.py new file mode 100644 index 0000000000000000000000000000000000000000..d920fa9dc776c2cd683d5378dbf9ae10aa00ed92 --- /dev/null +++ b/pystencils/nbackend/c_printer.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from pymbolic.mapper.c_code import CCodeMapper + +from .ast import ast_visitor, PsAstNode, PsBlock, PsExpression, PsDeclaration, PsAssignment, PsLoop + + +class CPrinter: + def __init__(self, indent_width=3): + self._indent_width = indent_width + + self._current_indent_level = 0 + self._inside_expression = False # controls parentheses in nested arithmetic expressions + + self._pb_cmapper = CCodeMapper() + + def indent(self, line): + return " " * self._current_indent_level + line + + def print(self, node: PsAstNode): + return self.visit(node) + + @ast_visitor + def visit(self, node: PsAstNode): + raise ValueError("Cannot print this node.") + + @visit.case(PsBlock) + def block(self, block: PsBlock): + if not block.children: + return self.indent("{ }") + + self._current_indent_level += self._indent_width + interior = "".join(self.visit(c) for c in block.children) + self._current_indent_level -= self._indent_width + return self.indent("{\n") + interior + self.indent("}\n") + + @visit.case(PsExpression) + def pymb_expression(self, expr: PsExpression): + return self._pb_cmapper(expr.expression) + + @visit.case(PsDeclaration) + def declaration(self, decl: PsDeclaration): + lhs_symb = decl.lhs.symbol + lhs_dtype = lhs_symb.dtype + rhs_code = self.visit(decl.rhs) + + return self.indent(f"{lhs_dtype} {lhs_symb.name} = {rhs_code};\n") + + @visit.case(PsAssignment) + def assignment(self, asm: PsAssignment): + lhs_code = self.visit(asm.lhs) + rhs_code = self.visit(asm.rhs) + return self.indent(f"{lhs_code} = {rhs_code};\n") + + @visit.case(PsLoop) + def loop(self, loop: PsLoop): + ctr_symbol = loop.counter.expression + ctr = ctr_symbol.name + start_code = self.visit(loop.start) + stop_code = self.visit(loop.stop) + step_code = self.visit(loop.step) + + body_code = self.visit(loop.body) + + code = f"for({ctr_symbol.dtype} {ctr} = {start_code};" + \ + f" {ctr} < {stop_code};" + \ + f" {ctr} += {step_code})\n" + \ + body_code + return code diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..055a0b904f365b7a5c4dc9cec15401a18d13cf42 --- /dev/null +++ b/pystencils/nbackend/typed_expressions.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import TypeAlias, Union, Any + +import pymbolic.primitives as pb + +from ..typing import AbstractType, BasicType + +class PsTypedSymbol(pb.Variable): + def __init__(self, name: str, dtype: AbstractType): + super(PsTypedSymbol, self).__init__(name) + self._dtype = dtype + + @property + def dtype(self) -> AbstractType: + return self._dtype + + +class PsArrayBasePointer(PsTypedSymbol): + def __init__(self, name: str, base_type: AbstractType): + super(PsArrayBasePointer, self).__init__(name, base_type) + + +class PsArrayAccess(pb.Subscript): + def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression): + super(PsArrayAccess, self).__init__(base_ptr, index) + + +PsLvalue : TypeAlias = Union[PsTypedSymbol, PsArrayAccess] + + +class PsTypedConstant: + + @staticmethod + def _cast(value, target_dtype: AbstractType): + if isinstance(value, PsTypedConstant): + if value._dtype != target_dtype: + raise ValueError(f"Incompatible types: {value._dtype} and {target_dtype}") + return value + + # TODO check legality + return PsTypedConstant(value, target_dtype) + + def __init__(self, value, dtype: AbstractType): + """Represents typed constants occuring in the pystencils AST""" + if isinstance(dtype, BasicType): + dtype = BasicType(dtype, const = True) + self._value = dtype.numpy_dtype.type(value) + else: + raise ValueError(f"Cannot create constant of type {dtype}") + + self._dtype = dtype + + def __str__(self) -> str: + return str(self._value) + + def __add__(self, other: Any): + other = PsTypedConstant._cast(other, self._dtype) + + return PsTypedConstant(self._value + other._value, self._dtype) + + def __mul__(self, other: Any): + other = PsTypedConstant._cast(other, self._dtype) + + return PsTypedConstant(self._value * other._value, self._dtype) + + def __sub__(self, other: Any): + other = PsTypedConstant._cast(other, self._dtype) + + return PsTypedConstant(self._value - other._value, self._dtype) + + # TODO: Remaining operators + + +pb.VALID_CONSTANT_CLASSES += (PsTypedConstant,) diff --git a/setup.py b/setup.py index 31392a747a564de22b4224e1d02e6ad1078fc1aa..fb6a5aeea28f360369e040e84b4d1bb5ef2a8993 100644 --- a/setup.py +++ b/setup.py @@ -90,7 +90,7 @@ setuptools.setup(name='pystencils', author_email='cs10-codegen@fau.de', url='https://i10git.cs.fau.de/pycodegen/pystencils/', packages=['pystencils'] + ['pystencils.' + s for s in setuptools.find_packages('pystencils')], - install_requires=['sympy>=1.6,<=1.11.1', 'numpy>=1.8.0', 'appdirs', 'joblib'], + install_requires=['sympy>=1.6,<=1.11.1', 'numpy>=1.8.0', 'pymbolic>=2022.2', 'appdirs', 'joblib'], package_data={'pystencils': ['include/*.h', 'backends/cuda_known_functions.txt', 'backends/opencl1.1_known_functions.txt', @@ -131,6 +131,6 @@ setuptools.setup(name='pystencils', 'ipython', 'randomgen>=1.18'], - python_requires=">=3.8", + python_requires=">=3.10", cmdclass=get_cmdclass() )