Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (16)
Showing
with 415 additions and 57 deletions
......@@ -326,9 +326,10 @@ nbackend-unit-tests:
before_script:
- pip install -e .[tests]
script:
- pytest tests/nbackend
- pytest tests/nbackend tests/symbolics
tags:
- docker
- cuda11
doctest:
stage: "Unit Tests"
......
......@@ -7,7 +7,6 @@ import pathlib
import nbformat
import pytest
from nbconvert import PythonExporter
# Trigger config file reading / creation once - to avoid race conditions when multiple instances are creating it
# at the same time
......@@ -134,6 +133,7 @@ class IPyNbTest(pytest.Item):
class IPyNbFile(pytest.File):
def collect(self):
from nbconvert import PythonExporter
exporter = PythonExporter()
exporter.exclude_markdown = True
exporter.exclude_input_prompt = True
......
......@@ -5,6 +5,6 @@ API Reference
.. toctree::
:maxdepth: 2
sympyextensions/index
symbolic_language/index
kernelcreation/index
types
-------------------------
Fields (pystencils.field)
-------------------------
-----------------------------
Fields API (pystencils.field)
-----------------------------
.. automodule:: pystencils.field
:members:
......
*****************
Symbolic Language
*****************
.. toctree::
:maxdepth: 1
field
astnodes
sympyextensions
Pystencils allows you to define near-arbitrarily complex numerical kernels in its symbolic
language, which is based on the computer algebra system `SymPy <https://www.sympy.org>`_.
The pystencils code generator is able to parse and translate a large portion of SymPy's
symbolic expression toolkit, and furthermore extends it with its own features.
Among the supported SymPy features are: symbols, constants, arithmetic and logical expressions,
trigonometric and most transcendental functions, as well as piecewise definitions.
Fields
======
The most important extension to SymPy brought by pystencils are *fields*.
Fields are a symbolic representation of multidimensional cartesian numerical arrays,
as used in many stencil algorithms.
They are represented by the `Field` class.
Piecewise Definitions
=====================
Pystencils can parse and translate piecewise function definitions using `sympy.Piecewise`
*only if* they have a default case.
So, for instance,
.. code-block:: Python
sp.Piecewise((0, x < 0), (1, x >= 0))
will result in an error from pystencils, while the equivalent
.. code-block:: Python
sp.Piecewise((0, x < 0), (1, True))
will be accepted. This is because pystencils cannot reason about whether or not
the given cases completely cover the entire possible input range.
Integer Operations
==================
Division and Remainder
----------------------
Care has to be taken when working with integer division operations in pystencils.
The python operators ``//`` and ``%`` work differently from their counterparts in the C family of languages.
Where in C, integer division always rounds toward zero, ``//`` performs a floor-divide (or euclidean division)
which rounds toward negative infinity.
These two operations differ whenever one of the operands is negative.
Accordingly, in Python ``a % b`` returns the *euclidean modulus*,
while C ``a % b`` computes the *remainder* of division.
The euclidean modulus is always nonnegative, while the remainder, if nonzero, always has the same sign as ``a``.
When ``//`` and ``%`` occur in symbolic expressions given to pystencils, they are interpreted the Python-way.
This can lead to inefficient generated code, since Pythonic integer division does not map to the corresponding C
operators.
To achieve C behaviour (and efficient code), you can use `pystencils.symb.int_div` and `pystencils.symb.int_rem`
which translate to C ``/`` and ``%``, respectively.
When expressions are translated in an integer type context, the Python ``/`` operator (or `sympy.Div`)
will also be converted to C-style ``/`` integer division.
Still, use of ``/`` for integers is discouraged, as it is designed to return a floating-point value in Python.
-------------------
Extensions to SymPy
-------------------
.. automodule:: pystencils.symb
:members:
*****************
Symbolic Language
*****************
.. toctree::
:maxdepth: 1
field
astnodes
......@@ -16,3 +16,6 @@ ignore_missing_imports=true
[mypy-appdirs.*]
ignore_missing_imports=true
[mypy-islpy.*]
ignore_missing_imports=true
......@@ -6,8 +6,14 @@ from . import fd
from . import stencil as stencil
from .display_utils import get_code_obj, get_code_str, show_code, to_dot
from .field import Field, FieldType, fields
from .types import create_type
from .cache import clear_cache
from .config import CreateKernelConfig, CpuOptimConfig, VectorizationConfig
from .config import (
CreateKernelConfig,
CpuOptimConfig,
VectorizationConfig,
OpenMpConfig,
)
from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel
from .backend.kernelfunction import KernelFunction
......@@ -34,10 +40,12 @@ __all__ = [
"fields",
"DEFAULTS",
"TypedSymbol",
"create_type",
"make_slice",
"CreateKernelConfig",
"CpuOptimConfig",
"VectorizationConfig",
"OpenMpConfig",
"create_kernel",
"KernelFunction",
"Target",
......
......@@ -5,7 +5,8 @@ from .structural import (
PsAssignment,
PsAstNode,
PsBlock,
PsComment,
PsEmptyLeafMixIn,
PsConditional,
PsDeclaration,
PsExpression,
PsLoop,
......@@ -56,7 +57,13 @@ class UndefinedSymbolsCollector:
undefined_vars.discard(ctr.symbol)
return undefined_vars
case PsComment():
case PsConditional(cond, branch_true, branch_false):
undefined_vars = self(cond) | self(branch_true)
if branch_false is not None:
undefined_vars |= self(branch_false)
return undefined_vars
case PsEmptyLeafMixIn():
return set()
case unknown:
......@@ -85,10 +92,11 @@ class UndefinedSymbolsCollector:
case (
PsAssignment()
| PsBlock()
| PsComment()
| PsConditional()
| PsExpression()
| PsLoop()
| PsStatement()
| PsEmptyLeafMixIn()
):
return set()
......
......@@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
def __repr__(self) -> str:
return f"PsConstantExpr({repr(self._constant)})"
class PsLiteralExpr(PsLeafMixIn, PsExpression):
__match_args__ = ("literal",)
......@@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression):
def clone(self) -> PsLiteralExpr:
return PsLiteralExpr(self._literal)
def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsLiteralExpr):
return False
......@@ -419,6 +419,58 @@ class PsCall(PsExpression):
if not isinstance(other, PsCall):
return False
return super().structurally_equal(other) and self._function == other._function
def __str__(self):
args = ", ".join(str(arg) for arg in self._args)
return f"PsCall({self._function}, ({args}))"
class PsTernary(PsExpression):
"""Ternary operator."""
__match_args__ = ("condition", "case_then", "case_else")
def __init__(
self, cond: PsExpression, then: PsExpression, els: PsExpression
) -> None:
super().__init__()
self._cond = cond
self._then = then
self._else = els
@property
def condition(self) -> PsExpression:
return self._cond
@property
def case_then(self) -> PsExpression:
return self._then
@property
def case_else(self) -> PsExpression:
return self._else
def clone(self) -> PsExpression:
return PsTernary(self._cond.clone(), self._then.clone(), self._else.clone())
def get_children(self) -> tuple[PsExpression, ...]:
return (self._cond, self._then, self._else)
def set_child(self, idx: int, c: PsAstNode):
idx = range(3)[idx]
match idx:
case 0:
self._cond = failing_cast(PsExpression, c)
case 1:
self._then = failing_cast(PsExpression, c)
case 2:
self._else = failing_cast(PsExpression, c)
def __str__(self) -> str:
return f"PsTernary({self._cond}, {self._then}, {self._else})"
def __repr__(self) -> str:
return f"PsTernary({repr(self._cond)}, {repr(self._then)}, {repr(self._else)})"
class PsNumericOpTrait:
......@@ -582,9 +634,21 @@ class PsDiv(PsBinOp, PsNumericOpTrait):
class PsIntDiv(PsBinOp, PsIntOpTrait):
"""C-like integer division (round to zero)."""
# python_operator not implemented because both floordiv and truediv have
# different semantics.
pass
@property
def python_operator(self) -> Callable[[Any, Any], Any]:
from ...utils import c_intdiv
return c_intdiv
class PsRem(PsBinOp, PsIntOpTrait):
"""C-style integer division remainder"""
@property
def python_operator(self) -> Callable[[Any, Any], Any]:
from ...utils import c_rem
return c_rem
class PsLeftShift(PsBinOp, PsIntOpTrait):
......
......@@ -4,32 +4,32 @@ from .structural import PsAstNode
def dfs_preorder(
node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True
node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True
) -> Generator[PsAstNode, None, None]:
"""Pre-Order depth-first traversal of an abstract syntax tree.
Args:
node: The tree's root node
yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True
filter_pred: Filter predicate; a node is only returned to the caller if `yield_pred(node)` returns True
"""
if yield_pred(node):
if filter_pred(node):
yield node
for c in node.children:
yield from dfs_preorder(c, yield_pred)
yield from dfs_preorder(c, filter_pred)
def dfs_postorder(
node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True
node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True
) -> Generator[PsAstNode, None, None]:
"""Post-Order depth-first traversal of an abstract syntax tree.
Args:
node: The tree's root node
yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True
filter_pred: Filter predicate; a node is only returned to the caller if `yield_pred(node)` returns True
"""
for c in node.children:
yield from dfs_postorder(c, yield_pred)
yield from dfs_postorder(c, filter_pred)
if yield_pred(node):
if filter_pred(node):
yield node
......@@ -307,7 +307,42 @@ class PsConditional(PsAstNode):
assert False, "unreachable code"
class PsComment(PsLeafMixIn, PsAstNode):
class PsEmptyLeafMixIn:
"""Mix-in marking AST leaves that can be treated as empty by the code generator,
such as comments and preprocessor directives."""
pass
class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
"""A C/C++ preprocessor pragma.
Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``.
Args:
text: The pragma's text, without the ``#pragma ``.
"""
__match_args__ = ("text",)
def __init__(self, text: str) -> None:
self._text = text
@property
def text(self) -> str:
return self._text
def clone(self) -> PsPragma:
return PsPragma(self.text)
def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsPragma):
return False
return self._text == other._text
class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
__match_args__ = ("lines",)
def __init__(self, text: str) -> None:
......
......@@ -10,6 +10,7 @@ from .ast.structural import (
PsLoop,
PsConditional,
PsComment,
PsPragma,
)
from .ast.expressions import (
......@@ -25,6 +26,7 @@ from .ast.expressions import (
PsConstantExpr,
PsDeref,
PsDiv,
PsRem,
PsIntDiv,
PsLeftShift,
PsLookup,
......@@ -36,6 +38,7 @@ from .ast.expressions import (
PsSymbolExpr,
PsLiteralExpr,
PsVectorArrayAccess,
PsTernary,
PsAnd,
PsOr,
PsNot,
......@@ -111,6 +114,8 @@ class Ops(Enum):
LogicOr = (15, LR.Left)
Ternary = (16, LR.Right)
Weakest = (17, LR.Middle)
def __init__(self, pred: int, assoc: LR) -> None:
......@@ -235,6 +240,9 @@ class CAstPrinter:
lines_list[-1] = lines_list[-1] + " */"
return pc.indent("\n".join(lines_list))
case PsPragma(text):
return pc.indent("#pragma " + text)
case PsSymbolExpr(symbol):
return symbol.name
......@@ -246,7 +254,7 @@ class CAstPrinter:
)
return dtype.create_literal(constant.value)
case PsLiteralExpr(lit):
return lit.text
......@@ -325,6 +333,19 @@ class CAstPrinter:
type_str = target_type.c_string()
return pc.parenthesize(f"({type_str}) {operand_code}", Ops.Cast)
case PsTernary(cond, then, els):
pc.push_op(Ops.Ternary, LR.Left)
cond_code = self.visit(cond, pc)
pc.switch_branch(LR.Middle)
then_code = self.visit(then, pc)
pc.switch_branch(LR.Right)
else_code = self.visit(els, pc)
pc.pop_op()
return pc.parenthesize(
f"{cond_code} ? {then_code} : {else_code}", Ops.Ternary
)
case PsArrayInitList(items):
pc.push_op(Ops.Weakest, LR.Middle)
items_str = ", ".join(self.visit(item, pc) for item in items)
......@@ -358,6 +379,8 @@ class CAstPrinter:
return ("*", Ops.Mul)
case PsDiv() | PsIntDiv():
return ("/", Ops.Div)
case PsRem():
return ("%", Ops.Rem)
case PsLeftShift():
return ("<<", Ops.LeftShift)
case PsRightShift():
......
......@@ -33,16 +33,25 @@ class MathFunctions(Enum):
"""
Exp = ("exp", 1)
Log = ("log", 1)
Sin = ("sin", 1)
Cos = ("cos", 1)
Tan = ("tan", 1)
Sinh = ("sinh", 1)
Cosh = ("cosh", 1)
ASin = ("asin", 1)
ACos = ("acos", 1)
ATan = ("atan", 1)
Abs = ("abs", 1)
Floor = ("floor", 1)
Ceil = ("ceil", 1)
Min = ("min", 2)
Max = ("max", 2)
Pow = ("pow", 2)
ATan2 = ("atan2", 2)
def __init__(self, func_name, num_args):
self.function_name = func_name
......@@ -137,3 +146,15 @@ class PsMathFunction(PsFunction):
@property
def func(self) -> MathFunctions:
return self._func
def __str__(self) -> str:
return f"{self._func.function_name}"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsMathFunction):
return False
return self._func == other._func
def __hash__(self) -> int:
return hash(self._func)
......@@ -3,10 +3,9 @@ from typing import cast
from .context import KernelCreationContext
from ..platforms import GenericCpu
from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
from ..ast.structural import PsBlock
from ...config import CpuOptimConfig
from ...config import CpuOptimConfig, OpenMpConfig
def optimize_cpu(
......@@ -16,6 +15,7 @@ def optimize_cpu(
cfg: CpuOptimConfig | None,
) -> PsBlock:
"""Carry out CPU-specific optimizations according to the given configuration."""
from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
canonicalize = CanonicalizeSymbols(ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
......@@ -32,8 +32,12 @@ def optimize_cpu(
if cfg.vectorize is not False:
raise NotImplementedError("Vectorization not implemented yet")
if cfg.openmp:
raise NotImplementedError("OpenMP not implemented yet")
if cfg.openmp is not False:
from ..transformations import AddOpenMP
params = cfg.openmp if isinstance(cfg.openmp, OpenMpConfig) else OpenMpConfig()
add_omp = AddOpenMP(ctx, params)
kernel_ast = cast(PsBlock, add_omp(kernel_ast))
if cfg.use_cacheline_zeroing:
raise NotImplementedError("CL-zeroing not implemented yet")
......
......@@ -7,8 +7,14 @@ import sympy.core.relational
import sympy.logic.boolalg
from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
from ...sympyextensions import Assignment, AssignmentCollection, integer_functions
from ...sympyextensions import (
Assignment,
AssignmentCollection,
integer_functions,
ConditionalFieldAccess,
)
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc
from ...sympyextensions.pointers import AddressOf
from ...field import Field, FieldType
from .context import KernelCreationContext
......@@ -27,15 +33,18 @@ from ..ast.expressions import (
PsBitwiseAnd,
PsBitwiseOr,
PsBitwiseXor,
PsAddressOf,
PsCall,
PsCast,
PsConstantExpr,
PsIntDiv,
PsRem,
PsLeftShift,
PsLookup,
PsRightShift,
PsSubscript,
PsVectorArrayAccess,
PsTernary,
PsRel,
PsEq,
PsNe,
......@@ -349,6 +358,12 @@ class FreezeExpressions:
else:
return PsArrayAccess(ptr, index)
def map_ConditionalFieldAccess(self, acc: ConditionalFieldAccess):
facc = self.visit_expr(acc.access)
condition = self.visit_expr(acc.outofbounds_condition)
fallback = self.visit_expr(acc.outofbounds_value)
return PsTernary(condition, fallback, facc)
def map_Function(self, func: sp.Function) -> PsExpression:
"""Map SymPy function calls by mapping sympy function classes to backend-supported function symbols.
......@@ -360,16 +375,36 @@ class FreezeExpressions:
match func:
case sp.Abs():
return PsCall(PsMathFunction(MathFunctions.Abs), args)
case sp.floor():
return PsCall(PsMathFunction(MathFunctions.Floor), args)
case sp.ceiling():
return PsCall(PsMathFunction(MathFunctions.Ceil), args)
case sp.exp():
return PsCall(PsMathFunction(MathFunctions.Exp), args)
case sp.log():
return PsCall(PsMathFunction(MathFunctions.Log), args)
case sp.sin():
return PsCall(PsMathFunction(MathFunctions.Sin), args)
case sp.cos():
return PsCall(PsMathFunction(MathFunctions.Cos), args)
case sp.tan():
return PsCall(PsMathFunction(MathFunctions.Tan), args)
case sp.sinh():
return PsCall(PsMathFunction(MathFunctions.Sinh), args)
case sp.cosh():
return PsCall(PsMathFunction(MathFunctions.Cosh), args)
case sp.asin():
return PsCall(PsMathFunction(MathFunctions.ASin), args)
case sp.acos():
return PsCall(PsMathFunction(MathFunctions.ACos), args)
case sp.atan():
return PsCall(PsMathFunction(MathFunctions.ATan), args)
case sp.atan2():
return PsCall(PsMathFunction(MathFunctions.ATan2), args)
case integer_functions.int_div():
return PsIntDiv(*args)
case integer_functions.int_rem():
return PsRem(*args)
case integer_functions.bit_shift_left():
return PsLeftShift(*args)
case integer_functions.bit_shift_right():
......@@ -388,16 +423,46 @@ class FreezeExpressions:
# TODO: requires if *expression*
# case integer_functions.modulo_ceil():
# case integer_functions.div_ceil():
case AddressOf():
return PsAddressOf(*args)
case _:
raise FreezeError(f"Unsupported function: {func}")
def map_Piecewise(self, expr: sp.Piecewise) -> PsTernary:
from sympy.functions.elementary.piecewise import ExprCondPair
cases: list[ExprCondPair] = cast(list[ExprCondPair], expr.args)
if cases[-1].cond != sp.true:
raise FreezeError(
"The last case of a `Piecewise` must be the fallback case, its condition must always be `True`."
)
conditions = [self.visit_expr(c.cond) for c in cases[:-1]]
subexprs = [self.visit_expr(c.expr) for c in cases]
last_expr = subexprs.pop()
ternary = PsTernary(conditions.pop(), subexprs.pop(), last_expr)
while conditions:
ternary = PsTernary(conditions.pop(), subexprs.pop(), ternary)
return ternary
def map_Min(self, expr: sp.Min) -> PsCall:
args = tuple(self.visit_expr(arg) for arg in expr.args)
return PsCall(PsMathFunction(MathFunctions.Min), args)
return self._minmax(expr, PsMathFunction(MathFunctions.Min))
def map_Max(self, expr: sp.Max) -> PsCall:
args = tuple(self.visit_expr(arg) for arg in expr.args)
return PsCall(PsMathFunction(MathFunctions.Max), args)
return self._minmax(expr, PsMathFunction(MathFunctions.Max))
def _minmax(self, expr: sp.Min | sp.Max, func: PsMathFunction) -> PsCall:
args = [self.visit_expr(arg) for arg in expr.args]
while len(args) > 1:
args = [
(PsCall(func, (args[i], args[i + 1])) if i + 1 < len(args) else args[i])
for i in range(0, len(args), 2)
]
return cast(PsCall, args[0])
def map_CastFunc(self, cast_expr: CastFunc) -> PsCast:
return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr))
......@@ -421,12 +486,12 @@ class FreezeExpressions:
raise FreezeError(f"Unsupported relation: {other}")
def map_And(self, conj: sympy.logic.And) -> PsAnd:
arg1, arg2 = [self.visit_expr(arg) for arg in conj.args]
return PsAnd(arg1, arg2)
args = [self.visit_expr(arg) for arg in conj.args]
return reduce(PsAnd, args) # type: ignore
def map_Or(self, disj: sympy.logic.Or) -> PsOr:
arg1, arg2 = [self.visit_expr(arg) for arg in disj.args]
return PsOr(arg1, arg2)
args = [self.visit_expr(arg) for arg in disj.args]
return reduce(PsOr, args) # type: ignore
def map_Not(self, neg: sympy.logic.Not) -> PsNot:
arg = self.visit_expr(neg.args[0])
......
......@@ -11,7 +11,7 @@ from ...field import Field, FieldType
from ..symbols import PsSymbol
from ..constants import PsConstant
from ..ast.expressions import PsExpression, PsConstantExpr
from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem
from ..arrays import PsLinearizedArray
from ..ast.util import failing_cast
from ...types import PsStructType, constify
......@@ -210,14 +210,37 @@ class FullIterationSpace(IterationSpace):
return self._archetype_field
def actual_iterations(self, dimension: int | None = None) -> PsExpression:
from .typification import Typifier
from ..transformations import EliminateConstants
typify = Typifier(self._ctx)
fold = EliminateConstants(self._ctx)
if dimension is None:
return reduce(
mul, (self.actual_iterations(d) for d in range(len(self.dimensions)))
return fold(
typify(
reduce(
mul,
(
self.actual_iterations(d)
for d in range(len(self.dimensions))
),
)
)
)
else:
dim = self.dimensions[dimension]
one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype))
return one + (dim.stop - dim.start - one) / dim.step
zero = PsConstantExpr(PsConstant(0, self._ctx.index_dtype))
return fold(
typify(
PsTernary(
PsEq(PsRem((dim.stop - dim.start), dim.step), zero),
(dim.stop - dim.start) / dim.step,
(dim.stop - dim.start) / dim.step + one,
)
)
)
def compressed_counter(self) -> PsExpression:
"""Expression counting the actual number of items processed at the iteration defined by the counter tuple.
......
......@@ -13,6 +13,7 @@ from ...types import (
PsPointerType,
PsBoolType,
constify,
deconstify,
)
from ..ast.structural import (
PsAstNode,
......@@ -22,7 +23,7 @@ from ..ast.structural import (
PsExpression,
PsAssignment,
PsDeclaration,
PsComment,
PsEmptyLeafMixIn,
)
from ..ast.expressions import (
PsArrayAccess,
......@@ -32,6 +33,7 @@ from ..ast.expressions import (
PsNumericOpTrait,
PsBoolOpTrait,
PsCall,
PsTernary,
PsCast,
PsDeref,
PsAddressOf,
......@@ -159,7 +161,7 @@ class TypeContext:
f" Constant type: {c.dtype}\n"
f" Target type: {self._target_type}"
)
case PsLiteralExpr(lit):
if not self._compatible(lit.dtype):
raise TypificationError(
......@@ -336,7 +338,7 @@ class Typifier:
self.visit(body)
case PsComment():
case PsEmptyLeafMixIn():
pass
case _:
......@@ -412,6 +414,11 @@ class Typifier:
tc.apply_dtype(ptr_tc.target_type.base_type, expr)
case PsAddressOf(arg):
if not isinstance(arg, (PsSymbolExpr, PsSubscript, PsDeref, PsLookup)):
raise TypificationError(
f"Illegal expression below AddressOf operator: {arg}"
)
arg_tc = TypeContext()
self.visit_expr(arg, arg_tc)
......@@ -420,7 +427,29 @@ class Typifier:
f"Unable to determine type of argument to AddressOf: {arg}"
)
ptr_type = PsPointerType(arg_tc.target_type, const=True)
# Inherit pointed-to type from referenced object, not from the subexpression
match arg:
case PsSymbolExpr(s):
pointed_to_type = s.get_dtype()
case PsSubscript(arr, _) | PsDeref(arr):
arr_type = arr.get_dtype()
assert isinstance(arr_type, PsDereferencableType)
pointed_to_type = arr_type.base_type
case PsLookup(aggr, member_name):
struct_type = aggr.get_dtype()
assert isinstance(struct_type, PsStructType)
if struct_type.const:
pointed_to_type = constify(
struct_type.get_member(member_name).dtype
)
else:
pointed_to_type = deconstify(
struct_type.get_member(member_name).dtype
)
case _:
assert False, "unreachable code"
ptr_type = PsPointerType(pointed_to_type, const=True)
tc.apply_dtype(ptr_type, expr)
case PsLookup(aggr, member_name):
......@@ -437,7 +466,7 @@ class Typifier:
member = aggr_type.find_member(member_name)
if member is None:
raise TypificationError(
f"Aggregate of type {aggr_type} does not have a member {member}."
f"Aggregate of type {aggr_type} does not have a member {member_name}."
)
member_type = member.dtype
......@@ -446,6 +475,14 @@ class Typifier:
tc.apply_dtype(member_type, expr)
case PsTernary(cond, then, els):
cond_tc = TypeContext(target_type=PsBoolType())
self.visit_expr(cond, cond_tc)
self.visit_expr(then, tc)
self.visit_expr(els, tc)
tc.infer_dtype(expr)
case PsRel(op1, op2):
args_tc = TypeContext()
self.visit_expr(op1, args_tc)
......