Skip to content
Snippets Groups Projects
Commit e1c2463c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix: Lookups and Derefs are now Lvalues

parent 24b61fcf
Branches
Tags
No related merge requests found
Pipeline #64311 failed
...@@ -60,15 +60,11 @@ class PsExpression(PsAstNode, ABC): ...@@ -60,15 +60,11 @@ class PsExpression(PsAstNode, ABC):
pass pass
class PsLvalueExpr(PsExpression, ABC): class PsLvalue(ABC):
"""Base class for all expressions that may occur as an lvalue""" """Mix-in for all expressions that may occur as an lvalue"""
@abstractmethod
def clone(self) -> PsLvalueExpr:
pass
class PsSymbolExpr(PsLeafMixIn, PsLvalueExpr): class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression):
"""A single symbol as an expression.""" """A single symbol as an expression."""
__match_args__ = ("symbol",) __match_args__ = ("symbol",)
...@@ -124,7 +120,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): ...@@ -124,7 +120,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
return f"Constant({repr(self._constant)})" return f"Constant({repr(self._constant)})"
class PsSubscript(PsLvalueExpr): class PsSubscript(PsLvalue, PsExpression):
__match_args__ = ("base", "index") __match_args__ = ("base", "index")
def __init__(self, base: PsExpression, index: PsExpression): def __init__(self, base: PsExpression, index: PsExpression):
...@@ -271,7 +267,7 @@ class PsVectorArrayAccess(PsArrayAccess): ...@@ -271,7 +267,7 @@ class PsVectorArrayAccess(PsArrayAccess):
) )
class PsLookup(PsExpression): class PsLookup(PsExpression, PsLvalue):
__match_args__ = ("aggregate", "member_name") __match_args__ = ("aggregate", "member_name")
def __init__(self, aggregate: PsExpression, member_name: str) -> None: def __init__(self, aggregate: PsExpression, member_name: str) -> None:
...@@ -384,7 +380,7 @@ class PsNeg(PsUnOp): ...@@ -384,7 +380,7 @@ class PsNeg(PsUnOp):
return operator.neg return operator.neg
class PsDeref(PsUnOp): class PsDeref(PsLvalue, PsUnOp):
pass pass
......
...@@ -3,7 +3,7 @@ from typing import Sequence, cast ...@@ -3,7 +3,7 @@ from typing import Sequence, cast
from types import NoneType from types import NoneType
from .astnode import PsAstNode, PsLeafMixIn from .astnode import PsAstNode, PsLeafMixIn
from .expressions import PsExpression, PsLvalueExpr, PsSymbolExpr from .expressions import PsExpression, PsLvalue, PsSymbolExpr
from .util import failing_cast from .util import failing_cast
...@@ -76,16 +76,20 @@ class PsAssignment(PsAstNode): ...@@ -76,16 +76,20 @@ class PsAssignment(PsAstNode):
"rhs", "rhs",
) )
def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression): def __init__(self, lhs: PsExpression, rhs: PsExpression):
self._lhs = lhs if not isinstance(lhs, PsLvalue):
raise ValueError("Assignment LHS must be an lvalue")
self._lhs: PsExpression = lhs
self._rhs = rhs self._rhs = rhs
@property @property
def lhs(self) -> PsLvalueExpr: def lhs(self) -> PsExpression:
return self._lhs return self._lhs
@lhs.setter @lhs.setter
def lhs(self, lvalue: PsLvalueExpr): def lhs(self, lvalue: PsExpression):
if not isinstance(lvalue, PsLvalue):
raise ValueError("Assignment LHS must be an lvalue")
self._lhs = lvalue self._lhs = lvalue
@property @property
...@@ -105,7 +109,7 @@ class PsAssignment(PsAstNode): ...@@ -105,7 +109,7 @@ class PsAssignment(PsAstNode):
def set_child(self, idx: int, c: PsAstNode): def set_child(self, idx: int, c: PsAstNode):
idx = [0, 1][idx] # trick to normalize index idx = [0, 1][idx] # trick to normalize index
if idx == 0: if idx == 0:
self._lhs = failing_cast(PsLvalueExpr, c) self.lhs = failing_cast(PsExpression, c)
elif idx == 1: elif idx == 1:
self._rhs = failing_cast(PsExpression, c) self._rhs = failing_cast(PsExpression, c)
else: else:
...@@ -125,11 +129,11 @@ class PsDeclaration(PsAssignment): ...@@ -125,11 +129,11 @@ class PsDeclaration(PsAssignment):
super().__init__(lhs, rhs) super().__init__(lhs, rhs)
@property @property
def lhs(self) -> PsLvalueExpr: def lhs(self) -> PsExpression:
return self._lhs return self._lhs
@lhs.setter @lhs.setter
def lhs(self, lvalue: PsLvalueExpr): def lhs(self, lvalue: PsExpression):
self._lhs = failing_cast(PsSymbolExpr, lvalue) self._lhs = failing_cast(PsSymbolExpr, lvalue)
@property @property
...@@ -146,7 +150,7 @@ class PsDeclaration(PsAssignment): ...@@ -146,7 +150,7 @@ class PsDeclaration(PsAssignment):
def set_child(self, idx: int, c: PsAstNode): def set_child(self, idx: int, c: PsAstNode):
idx = [0, 1][idx] # trick to normalize index idx = [0, 1][idx] # trick to normalize index
if idx == 0: if idx == 0:
self._lhs = failing_cast(PsSymbolExpr, c) self.lhs = failing_cast(PsSymbolExpr, c)
elif idx == 1: elif idx == 1:
self._rhs = failing_cast(PsExpression, c) self._rhs = failing_cast(PsExpression, c)
else: else:
......
...@@ -132,10 +132,12 @@ class FreezeExpressions: ...@@ -132,10 +132,12 @@ class FreezeExpressions:
if isinstance(lhs, PsSymbolExpr): if isinstance(lhs, PsSymbolExpr):
return PsDeclaration(lhs, rhs) return PsDeclaration(lhs, rhs)
elif isinstance(lhs, (PsArrayAccess, PsVectorArrayAccess)): # todo elif isinstance(lhs, (PsArrayAccess, PsLookup, PsVectorArrayAccess)): # todo
return PsAssignment(lhs, rhs) return PsAssignment(lhs, rhs)
else: else:
assert False, "That should not have happened." raise FreezeError(
f"Encountered unsupported expression on assignment left-hand side: {lhs}"
)
def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr: def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
symb = self._ctx.get_symbol(spsym.name) symb = self._ctx.get_symbol(spsym.name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment