Skip to content
Snippets Groups Projects
Commit 07af649b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

introduce AST cloning

parent b76caf99
Branches
Tags
No related merge requests found
Pipeline #63764 failed
......@@ -29,6 +29,10 @@ class PsAstNode(ABC):
def set_child(self, idx: int, c: PsAstNode):
pass
@abstractmethod
def clone(self) -> PsAstNode:
pass
def structurally_equal(self, other: PsAstNode) -> bool:
"""Check two ASTs for structural equality."""
return (
......
from __future__ import annotations
from abc import ABC
from abc import ABC, abstractmethod
from typing import Sequence, overload
from ..symbols import PsSymbol
......@@ -54,10 +54,18 @@ class PsExpression(PsAstNode, ABC):
else:
raise ValueError(f"Cannot make expression out of {obj}")
@abstractmethod
def clone(self) -> PsExpression:
pass
class PsLvalueExpr(PsExpression, ABC):
"""Base class for all expressions that may occur as an lvalue"""
@abstractmethod
def clone(self) -> PsLvalueExpr:
pass
class PsSymbolExpr(PsLeafMixIn, PsLvalueExpr):
"""A single symbol as an expression."""
......@@ -75,6 +83,9 @@ class PsSymbolExpr(PsLeafMixIn, PsLvalueExpr):
def symbol(self, symbol: PsSymbol):
self._symbol = symbol
def clone(self) -> PsSymbolExpr:
return PsSymbolExpr(self._symbol)
def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsSymbolExpr):
return False
......@@ -99,6 +110,9 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
def constant(self, c: PsConstant):
self._constant = c
def clone(self) -> PsConstantExpr:
return PsConstantExpr(self._constant)
def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsConstantExpr):
return False
......@@ -132,6 +146,9 @@ class PsSubscript(PsLvalueExpr):
def index(self, expr: PsExpression):
self._index = expr
def clone(self) -> PsSubscript:
return PsSubscript(self._base.clone(), self._index.clone())
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._base, self._index)
......@@ -180,6 +197,9 @@ class PsArrayAccess(PsSubscript):
"""Data type of this expression, i.e. the element type of the underlying array"""
return self._base_ptr.array.element_type
def clone(self) -> PsArrayAccess:
return PsArrayAccess(self._base_ptr, self._index.clone())
def __repr__(self) -> str:
return f"ArrayAccess({repr(self._base_ptr)}, {repr(self._index)})"
......@@ -226,6 +246,15 @@ class PsVectorArrayAccess(PsArrayAccess):
def alignment(self) -> int:
return self._alignment
def clone(self) -> PsVectorArrayAccess:
return PsVectorArrayAccess(
self._base_ptr,
self._index.clone(),
self.vector_entries,
self._stride,
self._alignment,
)
def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsVectorArrayAccess):
return False
......@@ -243,7 +272,7 @@ class PsLookup(PsExpression):
def __init__(self, aggregate: PsExpression, member_name: str) -> None:
self._aggregate = aggregate
self._member = member_name
self._member_name = member_name
@property
def aggregate(self) -> PsExpression:
......@@ -255,12 +284,15 @@ class PsLookup(PsExpression):
@property
def member_name(self) -> str:
return self._member
return self._member_name
@member_name.setter
def member_name(self, name: str):
self._name = name
def clone(self) -> PsLookup:
return PsLookup(self._aggregate.clone(), self._member_name)
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._aggregate,)
......@@ -298,6 +330,9 @@ class PsCall(PsExpression):
self._args = list(exprs)
def clone(self) -> PsCall:
return PsCall(self._function, [arg.clone() for arg in self._args])
def get_children(self) -> tuple[PsAstNode, ...]:
return self.args
......@@ -324,6 +359,9 @@ class PsUnOp(PsExpression):
def operand(self, expr: PsExpression):
self._operand = expr
def clone(self) -> PsUnOp:
return type(self)(self._operand.clone())
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._operand,)
......@@ -359,6 +397,9 @@ class PsCast(PsUnOp):
def target_type(self, dtype: PsType):
self._target_type = dtype
def clone(self) -> PsUnOp:
return PsCast(self._target_type, self._operand.clone())
def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsCast):
return False
......@@ -391,6 +432,9 @@ class PsBinOp(PsExpression):
def operand2(self, expr: PsExpression):
self._op2 = expr
def clone(self) -> PsBinOp:
return type(self)(self._op1.clone(), self._op2.clone())
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._op1, self._op2)
......
......@@ -20,6 +20,9 @@ class PsBlock(PsAstNode):
def set_child(self, idx: int, c: PsAstNode):
self._statements[idx] = c
def clone(self) -> PsBlock:
return PsBlock([stmt.clone() for stmt in self._statements])
@property
def statements(self) -> list[PsAstNode]:
return self._statements
......@@ -47,6 +50,9 @@ class PsStatement(PsAstNode):
def expression(self, expr: PsExpression):
self._expression = expr
def clone(self) -> PsStatement:
return PsStatement(self._expression.clone())
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._expression,)
......@@ -82,6 +88,9 @@ class PsAssignment(PsAstNode):
def rhs(self, expr: PsExpression):
self._rhs = expr
def clone(self) -> PsAssignment:
return PsAssignment(self._lhs.clone(), self._rhs.clone())
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._lhs, self._rhs)
......@@ -123,6 +132,9 @@ class PsDeclaration(PsAssignment):
def declared_variable(self, lvalue: PsSymbolExpr):
self._lhs = lvalue
def clone(self) -> PsDeclaration:
return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone())
def set_child(self, idx: int, c: PsAstNode):
idx = [0, 1][idx] # trick to normalize index
if idx == 0:
......@@ -193,6 +205,15 @@ class PsLoop(PsAstNode):
def body(self, block: PsBlock):
self._body = block
def clone(self) -> PsLoop:
return PsLoop(
self._ctr.clone(),
self._start.clone(),
self._stop.clone(),
self._step.clone(),
self._body.clone(),
)
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._ctr, self._start, self._stop, self._step, self._body)
......@@ -252,6 +273,13 @@ class PsConditional(PsAstNode):
def branch_false(self, block: PsBlock | None):
self._branch_false = block
def clone(self) -> PsConditional:
return PsConditional(
self._condition.clone(),
self._branch_true.clone(),
self._branch_false.clone() if self._branch_false is not None else None,
)
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._condition, self._branch_true) + (
(self._branch_false,) if self._branch_false is not None else ()
......@@ -285,6 +313,9 @@ class PsComment(PsLeafMixIn, PsAstNode):
def lines(self) -> tuple[str, ...]:
return self._lines
def clone(self) -> PsComment:
return PsComment(self._text)
def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsComment):
return False
......
......@@ -117,7 +117,9 @@ class FreezeExpressions:
for summand in expr.args:
if summand.is_negative:
signs.append(-1)
elif isinstance(summand, sp.Mul) and any(factor.is_negative for factor in summand.args):
elif isinstance(summand, sp.Mul) and any(
factor.is_negative for factor in summand.args
):
signs.append(-1)
else:
signs.append(1)
......@@ -126,18 +128,18 @@ class FreezeExpressions:
for sign, arg in zip(signs[1:], expr.args[1:]):
if sign == -1:
arg = - arg
arg = -arg
op = sub
else:
op = add
frozen_expr = op(frozen_expr, self.visit_expr(arg))
return frozen_expr
def map_Mul(self, expr: sp.Mul) -> PsExpression:
return reduce(mul, (self.visit_expr(arg) for arg in expr.args))
def map_Pow(self, expr: sp.Pow) -> PsExpression:
base = expr.args[0]
exponent = expr.args[1]
......@@ -147,18 +149,29 @@ class FreezeExpressions:
expand_product = False
if exponent.is_Integer:
if exponent == 0:
return PsExpression.make(PsConstant(1))
if exponent.is_negative:
reciprocal = True
exponent = - exponent
exponent = -exponent
if exponent <= sp.Integer(5):
if exponent <= sp.Integer(
5
): # TODO: is this a sensible limit? maybe make this configurable.
expand_product = True
if expand_product:
frozen_expr = reduce(mul, [base_frozen] * int(exponent))
frozen_expr = reduce(
mul,
[base_frozen]
+ [base_frozen.clone() for _ in range(0, int(exponent) - 1)],
)
else:
exponent_frozen = self.visit_expr(exponent)
frozen_expr = PsMathFunction(MathFunctions.Pow)(base_frozen, exponent_frozen)
frozen_expr = PsMathFunction(MathFunctions.Pow)(
base_frozen, exponent_frozen
)
if reciprocal:
one = PsExpression.make(PsConstant(1))
......
......@@ -42,10 +42,10 @@ class KernelParameter:
type(self) is type(other)
and self._hashable_contents() == other._hashable_contents()
)
def __str__(self) -> str:
return self._name
def __repr__(self) -> str:
return f"{type(self).__name__}(name = {self._name}, dtype = {self._dtype})"
......@@ -60,7 +60,7 @@ class FieldParameter(KernelParameter, ABC):
@property
def field(self):
return self._field
def _hashable_contents(self):
return super()._hashable_contents() + (self._field,)
......@@ -75,7 +75,7 @@ class FieldShapeParam(FieldParameter):
@property
def coordinate(self):
return self._coordinate
def _hashable_contents(self):
return super()._hashable_contents() + (self._coordinate,)
......@@ -90,7 +90,7 @@ class FieldStrideParam(FieldParameter):
@property
def coordinate(self):
return self._coordinate
def _hashable_contents(self):
return super()._hashable_contents() + (self._coordinate,)
......
......@@ -4,7 +4,7 @@ import uuid
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp
from sympy.codegen.ast import Assignment, AugmentedAssignment
from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment
from sympy.printing.latex import LatexPrinter
import numpy as np
......
from pystencils.backend.symbols import PsSymbol
from pystencils.backend.constants import PsConstant
from pystencils.backend.ast.expressions import (
PsExpression,
PsCast,
PsDeref,
PsSubscript,
)
from pystencils.backend.ast.structural import (
PsStatement,
PsAssignment,
PsBlock,
PsConditional,
PsComment,
PsLoop,
)
from pystencils.types.quick import Fp, Ptr
def test_cloning():
x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"]
c1 = PsExpression.make(PsConstant(3.0))
c2 = PsExpression.make(PsConstant(-1.0))
one = PsExpression.make(PsConstant(1))
def check(orig, clone):
assert not (orig is clone)
assert type(orig) is type(clone)
assert orig.structurally_equal(clone)
for c1, c2 in zip(orig.children, clone.children, strict=True):
check(c1, c2)
for ast in [
x,
y,
c1,
x + y,
x / y + c1,
c1 + c2,
PsStatement(x * y * z + c1),
PsAssignment(y, x / c1),
PsBlock([PsAssignment(x, c1 * y), PsAssignment(z, c2 + c1 * z)]),
PsConditional(
y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")])
),
PsLoop(
x,
y,
z,
one,
PsBlock(
[
PsComment("Loop body"),
PsAssignment(x, y),
PsAssignment(x, y),
PsStatement(
PsDeref(PsCast(Ptr(Fp(32)), z))
+ PsSubscript(z, one + one + one)
),
]
),
),
]:
ast_clone = ast.clone()
check(ast, ast_clone)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment