Skip to content
Snippets Groups Projects

Fix handling of constness in Typifier

Merged Frederik Hennig requested to merge fhennig/fix-const-typing into backend-rework
All threads resolved!
Viewing commit afdd0127
Next
Show latest version
3 files
+ 155
31
Preferences
Compare changes
Files
3
  • afdd0127
    Fix constness in typifier: · afdd0127
    Frederik Hennig authored
     - TypeContext now assumes `const` by default
     - Introduce `require_nonconst` to TypeContext
     - Check LHS for constness in `PsAssignment`s
     - Fix test cases
@@ -12,7 +12,7 @@ from ...types import (
@@ -12,7 +12,7 @@ from ...types import (
PsDereferencableType,
PsDereferencableType,
PsPointerType,
PsPointerType,
PsBoolType,
PsBoolType,
deconstify,
constify,
)
)
from ..ast.structural import (
from ..ast.structural import (
PsAstNode,
PsAstNode,
@@ -21,6 +21,7 @@ from ..ast.structural import (
@@ -21,6 +21,7 @@ from ..ast.structural import (
PsConditional,
PsConditional,
PsExpression,
PsExpression,
PsAssignment,
PsAssignment,
 
PsDeclaration,
)
)
from ..ast.expressions import (
from ..ast.expressions import (
PsArrayAccess,
PsArrayAccess,
@@ -54,10 +55,16 @@ NodeT = TypeVar("NodeT", bound=PsAstNode)
@@ -54,10 +55,16 @@ NodeT = TypeVar("NodeT", bound=PsAstNode)
class TypeContext:
class TypeContext:
def __init__(self, target_type: PsType | None = None):
def __init__(
self._target_type = deconstify(target_type) if target_type is not None else None
self, target_type: PsType | None = None, require_nonconst: bool = False
 
):
 
self._require_nonconst = require_nonconst
self._deferred_exprs: list[PsExpression] = []
self._deferred_exprs: list[PsExpression] = []
 
self._target_type = (
 
self._fix_constness(target_type) if target_type is not None else None
 
)
 
def apply_dtype(self, expr: PsExpression | None, dtype: PsType):
def apply_dtype(self, expr: PsExpression | None, dtype: PsType):
"""Applies the given ``dtype`` to the given expression inside this type context.
"""Applies the given ``dtype`` to the given expression inside this type context.
@@ -67,7 +74,7 @@ class TypeContext:
@@ -67,7 +74,7 @@ class TypeContext:
to all deferred expressions.
to all deferred expressions.
"""
"""
dtype = deconstify(dtype)
dtype = self._fix_constness(dtype)
if self._target_type is not None and dtype != self._target_type:
if self._target_type is not None and dtype != self._target_type:
raise TypificationError(
raise TypificationError(
@@ -80,14 +87,7 @@ class TypeContext:
@@ -80,14 +87,7 @@ class TypeContext:
self._propagate_target_type()
self._propagate_target_type()
if expr is not None:
if expr is not None:
if expr.dtype is None:
self._apply_target_type(expr)
self._apply_target_type(expr)
elif deconstify(expr.dtype) != self._target_type:
raise TypificationError(
"Type conflict: Predefined expression type did not match the context's target type\n"
f" Expression type: {dtype}\n"
f" Target type: {self._target_type}"
)
def infer_dtype(self, expr: PsExpression):
def infer_dtype(self, expr: PsExpression):
"""Infer the data type for the given expression.
"""Infer the data type for the given expression.
@@ -113,7 +113,7 @@ class TypeContext:
@@ -113,7 +113,7 @@ class TypeContext:
assert self._target_type is not None
assert self._target_type is not None
if expr.dtype is not None:
if expr.dtype is not None:
if deconstify(expr.dtype) != self.target_type:
if not self._compatible(expr.dtype):
raise TypificationError(
raise TypificationError(
f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
f" Expression type: {expr.dtype}\n"
f" Expression type: {expr.dtype}\n"
@@ -128,7 +128,7 @@ class TypeContext:
@@ -128,7 +128,7 @@ class TypeContext:
)
)
if c.dtype is None:
if c.dtype is None:
expr.constant = c.interpret_as(self._target_type)
expr.constant = c.interpret_as(self._target_type)
elif deconstify(c.dtype) != self._target_type:
elif not self._compatible(c.dtype):
raise TypificationError(
raise TypificationError(
f"Type mismatch at constant {c}: Constant type did not match the context's target type\n"
f"Type mismatch at constant {c}: Constant type did not match the context's target type\n"
f" Constant type: {c.dtype}\n"
f" Constant type: {c.dtype}\n"
@@ -136,7 +136,14 @@ class TypeContext:
@@ -136,7 +136,14 @@ class TypeContext:
)
)
case PsSymbolExpr(symb):
case PsSymbolExpr(symb):
symb.apply_dtype(self._target_type)
if symb.dtype is None:
 
symb.dtype = self._target_type
 
elif not self._compatible(symb.dtype):
 
raise TypificationError(
 
f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n"
 
f" Symbol type: {symb.dtype}\n"
 
f" Target type: {self._target_type}"
 
)
case (
case (
PsIntDiv()
PsIntDiv()
@@ -151,9 +158,30 @@ class TypeContext:
@@ -151,9 +158,30 @@ class TypeContext:
f" Expression: {expr}"
f" Expression: {expr}"
f" Type Context: {self._target_type}"
f" Type Context: {self._target_type}"
)
)
expr.dtype = self._target_type
# endif
# endif
 
expr.dtype = self._target_type
 
 
def _compatible(self, dtype: PsType):
 
assert self._target_type is not None
 
if self._target_type.const:
 
return constify(dtype) == self._target_type
 
else:
 
return dtype == self._target_type
 
 
def _fix_constness(self, dtype: PsType, expr: PsExpression | None = None):
 
if self._require_nonconst:
 
if dtype.const:
 
if expr is None:
 
raise TypificationError(
 
f"Type mismatch: Encountered {dtype} in non-constant context."
 
)
 
else:
 
raise TypificationError(
 
f"Type mismatch at expression {expr}: Encountered {dtype} in non-constant context."
 
)
 
return dtype
 
else:
 
return constify(dtype)
@property
@property
def target_type(self) -> PsType | None:
def target_type(self) -> PsType | None:
@@ -213,13 +241,21 @@ class Typifier:
@@ -213,13 +241,21 @@ class Typifier:
for s in statements:
for s in statements:
self.visit(s)
self.visit(s)
case PsAssignment(lhs, rhs):
case PsDeclaration(lhs, rhs):
tc = TypeContext()
tc = TypeContext()
# LHS defines target type; type context carries it to RHS
# LHS defines target type; type context carries it to RHS
self.visit_expr(lhs, tc)
self.visit_expr(lhs, tc)
assert tc.target_type is not None
assert tc.target_type is not None
self.visit_expr(rhs, tc)
self.visit_expr(rhs, tc)
 
case PsAssignment(lhs, rhs):
 
tc_lhs = TypeContext(require_nonconst=True)
 
self.visit_expr(lhs, tc_lhs)
 
assert tc_lhs.target_type is not None
 
 
tc_rhs = TypeContext(tc_lhs.target_type, require_nonconst=False)
 
self.visit_expr(rhs, tc_rhs)
 
case PsConditional(cond, branch_true, branch_false):
case PsConditional(cond, branch_true, branch_false):
cond_tc = TypeContext(PsBoolType(const=True))
cond_tc = TypeContext(PsBoolType(const=True))
self.visit_expr(cond, cond_tc)
self.visit_expr(cond, cond_tc)
@@ -233,10 +269,10 @@ class Typifier:
@@ -233,10 +269,10 @@ class Typifier:
if ctr.symbol.dtype is None:
if ctr.symbol.dtype is None:
ctr.symbol.apply_dtype(self._ctx.index_dtype)
ctr.symbol.apply_dtype(self._ctx.index_dtype)
tc = TypeContext(ctr.symbol.dtype)
tc_lhs = TypeContext(ctr.symbol.dtype)
self.visit_expr(start, tc)
self.visit_expr(start, tc_lhs)
self.visit_expr(stop, tc)
self.visit_expr(stop, tc_lhs)
self.visit_expr(step, tc)
self.visit_expr(step, tc_lhs)
self.visit(body)
self.visit(body)
@@ -247,10 +283,10 @@ class Typifier:
@@ -247,10 +283,10 @@ class Typifier:
"""Recursive processing of expression nodes"""
"""Recursive processing of expression nodes"""
match expr:
match expr:
case PsSymbolExpr(_):
case PsSymbolExpr(_):
if expr.dtype is None:
if expr.symbol.dtype is None:
tc.apply_dtype(expr, self._ctx.default_dtype)
tc.apply_dtype(expr, self._ctx.default_dtype)
else:
else:
tc.apply_dtype(expr, expr.dtype)
tc.apply_dtype(expr, expr.symbol.dtype)
case PsConstantExpr(_):
case PsConstantExpr(_):
tc.infer_dtype(expr)
tc.infer_dtype(expr)
@@ -325,8 +361,12 @@ class Typifier:
@@ -325,8 +361,12 @@ class Typifier:
raise TypificationError(
raise TypificationError(
f"Aggregate of type {aggr_type} does not have a member {member}."
f"Aggregate of type {aggr_type} does not have a member {member}."
)
)
 
 
member_type = member.dtype
 
if aggr_type.const:
 
member_type = constify(member_type)
tc.apply_dtype(expr, member.dtype)
tc.apply_dtype(expr, member_type)
case PsBinOp(op1, op2):
case PsBinOp(op1, op2):
self.visit_expr(op1, tc)
self.visit_expr(op1, tc)