From c0f2fb544fe086d782793508db392e1cfc0b5677 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 2 Apr 2024 10:31:05 +0200 Subject: [PATCH] Add one more test case and doc comments --- .../backend/kernelcreation/typification.py | 27 ++++++++++++++++--- .../kernelcreation/test_typification.py | 16 ++++++++--- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index b01b1d35e..6cb42fce0 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -55,6 +55,17 @@ NodeT = TypeVar("NodeT", bound=PsAstNode) class TypeContext: + """Typing context, with support for type inference and checking. + + Instances of this class are used to propagate and check data types across expression subtrees + of the AST. Each type context has: + + - A target type `target_type`, which shall be applied to all expressions it covers + - A set of restrictions on the target type: + - `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides + - Additional restrictions may be added in the future. + """ + def __init__( self, target_type: PsType | None = None, require_nonconst: bool = False ): @@ -222,11 +233,11 @@ class Typifier: **Typing Rules** - The following rules apply: + The following general rules apply: - The context's `default_dtype` is applied to all untyped symbols - - By default, all expressions receive a ``const`` type unless otherwise required - - The left-hand side of any non-declaration assignment must not be ``const`` + - By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's + left-hand side **Typing of symbol expressions** @@ -304,7 +315,15 @@ class Typifier: raise NotImplementedError(f"Can't typify {node}") def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None: - """Recursive processing of expression nodes""" + """Recursive processing of expression nodes. + + This method opens, expands, and closes typing contexts according to the respective expression's + typing rules. It may add or check restrictions only when opening or closing a type context. + + The actual type inference and checking during context expansion are performed by the methods + of `TypeContext`. ``visit_expr`` tells the typing context how to handle an expression by calling + either ``apply_dtype`` or ``infer_dtype``. + """ match expr: case PsSymbolExpr(_): if expr.symbol.dtype is None: diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 5e9d7cbec..d9cc5f9ce 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -6,7 +6,7 @@ from typing import cast from pystencils import Assignment, TypedSymbol, Field, FieldType -from pystencils.backend.ast.structural import PsDeclaration +from pystencils.backend.ast.structural import PsDeclaration, PsAssignment, PsExpression from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp from pystencils.backend.constants import PsConstant from pystencils.types import constify @@ -127,6 +127,17 @@ def test_lhs_constness(): with pytest.raises(TypificationError): _ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x))) + # Const LHS is only OK in declarations + + q = ctx.get_symbol("q", Fp(32, const=True)) + ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q)) + ast = typify(ast) + assert ast.lhs.dtype == Fp(32, const=True) + + ast = PsAssignment(PsExpression.make(q), PsExpression.make(q)) + with pytest.raises(TypificationError): + typify(ast) + def test_typify_structs(): ctx = KernelCreationContext(default_dtype=Fp(32)) @@ -278,6 +289,3 @@ def test_typify_constant_clones(): assert expr_clone.operand1.dtype is None assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None - - -# test_lhs_constness() -- GitLab