Skip to content
Snippets Groups Projects
Select Git revision
  • 03bd4bf42fe0b1905a8e25e95fc6042ff993fc1e
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

test_typification.py

Blame
  • test_typification.py 17.23 KiB
    import pytest
    import sympy as sp
    import numpy as np
    
    from typing import cast
    
    from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment
    
    from pystencils.backend.ast import dfs_preorder
    from pystencils.backend.ast.structural import (
        PsDeclaration,
        PsAssignment,
        PsExpression,
        PsConditional,
        PsBlock,
    )
    from pystencils.backend.ast.expressions import (
        PsArrayInitList,
        PsCast,
        PsConstantExpr,
        PsSymbolExpr,
        PsSubscript,
        PsBinOp,
        PsAnd,
        PsOr,
        PsNot,
        PsEq,
        PsNe,
        PsGe,
        PsLe,
        PsGt,
        PsLt,
        PsCall,
        PsTernary,
    )
    from pystencils.backend.constants import PsConstant
    from pystencils.backend.functions import CFunction
    from pystencils.types import constify, create_type, create_numeric_type
    from pystencils.types.quick import Fp, Int, Bool, Arr
    from pystencils.backend.kernelcreation.context import KernelCreationContext
    from pystencils.backend.kernelcreation.freeze import FreezeExpressions
    from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
    
    from pystencils.sympyextensions.integer_functions import (
        bit_shift_left,
        bit_shift_right,
        bitwise_and,
        bitwise_xor,
        bitwise_or,
    )
    
    
    def test_typify_simple():
        ctx = KernelCreationContext()
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)
    
        x, y, z = sp.symbols("x, y, z")
        asm = Assignment(z, 2 * x + y)
    
        fasm = freeze(asm)
        fasm = typify(fasm)
    
        assert isinstance(fasm, PsDeclaration)
    
        def check(expr):
            assert expr.dtype == constify(ctx.default_dtype)
            match expr:
                case PsConstantExpr(cs):
                    assert cs.value == 2
                    assert cs.dtype == constify(ctx.default_dtype)
                case PsSymbolExpr(symb):
                    assert symb.name in "xyz"
                    assert symb.dtype == ctx.default_dtype
                case PsBinOp(op1, op2):
                    check(op1)
                    check(op2)
                case _:
                    pytest.fail(f"Unexpected expression: {expr}")
    
        check(fasm.lhs)
        check(fasm.rhs)
    
    
    def test_rhs_constness():
        """Right-hand sides of typified assignments always carry const-qualified types."""
        default_type = Fp(32)
        ctx = KernelCreationContext(default_dtype=default_type)

        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        def make_field(name, dtype):
            #   One spatial dimension, a single index component, custom field type
            return Field.create_generic(
                name, 1, index_shape=(1,), dtype=dtype, field_type=FieldType.CUSTOM
            )

        f = make_field("f", default_type)
        f_const = make_field("f_const", constify(default_type))

        x, y, z = sp.symbols("x, y, z")

        #   Plain symbol on the LHS, field access on the RHS
        decl = typify(freeze(Assignment(x, f.absolute_access([0], [0]))))
        assert decl.rhs.get_dtype().const

        #   Field access on the LHS; mixed const / non-const operands on the RHS
        f_entry = f.absolute_access([0], [0])
        f_const_entry = f_const.absolute_access([0], [0])
        update = typify(freeze(Assignment(f_entry, f_entry * f_const_entry * x + y)))
        assert update.rhs.get_dtype().const
    
    
    def test_lhs_constness():
        """Check the constness rules for assignment left-hand sides.

        The LHS of an assignment to a mutable field must not be const. Assigning to a
        const field, to a member of a const struct, or to a const symbol outside of its
        declaration must raise a ``TypificationError``. A const LHS is only accepted
        in a ``PsDeclaration``.
        """
        default_type = Fp(32)
        ctx = KernelCreationContext(default_dtype=default_type)
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        f = Field.create_generic(
            "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
        )
        f_const = Field.create_generic(
            "f_const",
            1,
            index_shape=(1,),
            dtype=constify(default_type),
            field_type=FieldType.CUSTOM,
        )

        x, y = sp.symbols("x, y")

        #   Assignment LHS must not be const
        asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y)))
        assert not asm.lhs.get_dtype().const

        #   Cannot assign to const left-hand side
        with pytest.raises(TypificationError):
            _ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y)))

        #   Cannot assign to a member of a const struct
        np_struct = np.dtype([("size", np.uint32), ("data", np.float32)], align=True)
        struct_type = constify(create_type(np_struct))
        struct_field = Field.create_generic(
            "struct_field", 1, dtype=struct_type, field_type=FieldType.CUSTOM
        )

        with pytest.raises(TypificationError):
            _ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x)))

        #   Const LHS is only OK in declarations

        q = ctx.get_symbol("q", Fp(32, const=True))
        ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q))
        ast = typify(ast)
        assert ast.lhs.dtype == Fp(32, const=True)

        #   ... but the same const symbol cannot be the target of a plain assignment
        ast = PsAssignment(PsExpression.make(q), PsExpression.make(q))
        with pytest.raises(TypificationError):
            typify(ast)
    
    
    def test_typify_structs():
        """Struct member accesses are typed with the member's declared type.

        Reading and writing the ``float32`` member ``data`` works in a ``float32``
        context. Assigning the ``uint32`` member ``size`` to the untyped ``x``
        is rejected.
        """
        ctx = KernelCreationContext(default_dtype=Fp(32))
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        struct_dtype = np.dtype([("size", np.uint32), ("data", np.float32)], align=True)
        f = Field.create_generic("f", 1, dtype=struct_dtype, field_type=FieldType.CUSTOM)
        x = sp.Symbol("x")

        #   Good: the float member can be read into and written from `x`
        typify(freeze(Assignment(x, f.absolute_access((0,), "data"))))
        typify(freeze(Assignment(f.absolute_access((0,), "data"), x)))

        #   Bad: the unsigned integer member does not fit the floating-point context
        read_size = freeze(Assignment(x, f.absolute_access((0,), "size")))
        with pytest.raises(TypificationError):
            typify(read_size)
    
    
    def test_default_typing():
        ctx = KernelCreationContext()
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)
    
        x, y, z = sp.symbols("x, y, z")
        expr = freeze(2 * x + 3 * y + z - 4)
        expr = typify(expr)
    
        def check(expr):
            assert expr.dtype == constify(ctx.default_dtype)
            match expr:
                case PsConstantExpr(cs):
                    assert cs.value in (2, 3, -4)
                    assert cs.dtype == constify(ctx.default_dtype)
                case PsSymbolExpr(symb):
                    assert symb.name in "xyz"
                    assert symb.dtype == ctx.default_dtype
                case PsBinOp(op1, op2):
                    check(op1)
                    check(op2)
                case _:
                    pytest.fail(f"Unexpected expression: {expr}")
    
        check(expr)
    
    
    def test_nested_contexts_defaults():
        """Untyped constants in nested typing contexts fall back to the default type.

        This holds for the operands of relations (whose result is ``bool``) and for
        the operand of a cast to a different type.
        """
        fp16 = Fp(16)
        const_fp16 = Fp(16, const=True)

        ctx = KernelCreationContext(default_dtype=fp16)
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        x, y, z = sp.symbols("x, y, z")

        #   Relation between two (identical) untyped constants
        fortytwo = PsExpression.make(PsConstant(42))
        relation = typify(PsEq(fortytwo, fortytwo))
        assert relation.dtype == Bool(const=True)
        assert fortytwo.dtype == const_fp16
        assert fortytwo.constant.dtype == const_fp16

        #   Relation on the RHS of a declaration: the LHS becomes bool,
        #   while the compared constants still get the default type
        decl = typify(freeze(Assignment(x, sp.Ge(3, 4, evaluate=False))))
        assert ctx.get_symbol("x").dtype == Bool()
        constants = dfs_preorder(decl.rhs, lambda n: isinstance(n, PsConstantExpr))
        for const_expr in constants:
            assert const_expr.dtype == const_fp16
            assert const_expr.constant.dtype == const_fp16

        #   Cast to fp32: the declared symbol is fp32, the cast operand keeps the default
        thirtyone = PsExpression.make(PsConstant(31))
        cast_decl = typify(PsDeclaration(freeze(y), PsCast(Fp(32), thirtyone)))
        assert ctx.get_symbol("y").dtype == Fp(32)
        assert thirtyone.dtype == const_fp16
        assert thirtyone.constant.dtype == const_fp16
    
    
    def test_constant_decls():
        """Declaring untyped symbols from plain literals gives them the default type."""
        ctx = KernelCreationContext(default_dtype=Fp(16))
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        x, y = sp.symbols("x, y")

        #   Both a float literal and an integer literal end up as (const) fp16
        for symbol, literal in ((x, 3.0), (y, 42)):
            decl = typify(freeze(Assignment(symbol, literal)))
            assert ctx.get_symbol(symbol.name).dtype == Fp(16)
            assert decl.rhs.dtype == Fp(16, const=True)
            assert decl.rhs.constant.dtype == Fp(16, const=True)
    
    
    def test_constant_array_decls():
        """Nested tuples of literals declare (multi-dimensional) arrays of the default type."""
        ctx = KernelCreationContext(default_dtype=Fp(16))
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        x, y = sp.symbols("x, y")

        cases = [
            (x, (1, 2, 3, 4), Arr(Fp(16), 4)),
            (y, ((1, 2, 3, 4), (5, 6, 7, 8)), Arr(Arr(Fp(16), 4), 2)),
        ]
        for symbol, values, expected_type in cases:
            typify(freeze(Assignment(symbol, values)))
            assert ctx.get_symbol(symbol.name).dtype == expected_type
    
    
    def test_inline_arrays_1d():
        """An inline 1D array literal inside an expression takes its element type from context."""
        ctx = KernelCreationContext(default_dtype=Fp(16))
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        x, y = sp.symbols("x, y")
        idx = TypedSymbol("idx", Int(32))

        arr = cast(PsArrayInitList, freeze(sp.Tuple(1, 2, 3, 4)))
        lhs = freeze(x)
        rhs = freeze(y) + PsSubscript(arr, freeze(idx))
        #   The array elements should learn their type from the context, which gets it from `y`
        decl = typify(PsDeclaration(lhs, rhs))

        const_fp16 = Fp(16, const=True)
        assert decl.lhs.dtype == const_fp16
        assert decl.rhs.dtype == const_fp16

        assert arr.dtype == Arr(Fp(16), 4, const=True)
        for element in arr.items:
            assert element.dtype == const_fp16
    
    
    def test_inline_arrays_3d():
        """A nested 3D array literal is typed consistently at every nesting level."""
        ctx = KernelCreationContext(default_dtype=Fp(16))
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        x, y = sp.symbols("x, y")
        indices = [TypedSymbol(f"idx_{i}", Int(32)) for i in range(3)]

        arr = freeze(sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10))))

        lhs = freeze(x)
        y_expr = freeze(y)

        #   Build arr[idx_0][idx_1][idx_2], one subscript per index
        element = arr
        for index in indices:
            element = PsSubscript(element, freeze(index))

        #   The array elements should learn their type from the context, which gets it from `y`
        decl = typify(PsDeclaration(lhs, y_expr + element))

        const_fp16 = Fp(16, const=True)
        assert decl.lhs.dtype == const_fp16
        assert decl.rhs.dtype == const_fp16

        assert arr.dtype == Arr(Arr(Arr(Fp(16), 2), 3), 2, const=True)
        for plane in arr.items:
            assert plane.dtype == Arr(Arr(Fp(16), 2), 3, const=True)
            for row in plane.items:
                assert row.dtype == Arr(Fp(16), 2, const=True)
                for entry in row.items:
                    assert entry.dtype == const_fp16
    
    
    def test_lhs_inference():
        """Untyped left-hand side symbols infer their type from the right-hand side.

        The inferred type takes precedence over the context's default type
        (``float64`` here). Covers declarations frozen from ``Assignment``, a plain
        ``PsAssignment`` whose LHS symbol has no type yet, and a declaration whose
        RHS is a relation, which yields ``bool``.
        """
        ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        x, y = sp.symbols("x, y")
        q = TypedSymbol("q", np.float32)
        w = TypedSymbol("w", np.float16)

        #   Type of the RHS is propagated to the untyped LHS symbol

        asm = Assignment(x, 3 - q)
        fasm = typify(freeze(asm))

        assert ctx.get_symbol("x").dtype == Fp(32)
        assert fasm.lhs.dtype == constify(Fp(32))

        asm = Assignment(y, 3 - w)
        fasm = typify(freeze(asm))

        assert ctx.get_symbol("y").dtype == Fp(16)
        assert fasm.lhs.dtype == constify(Fp(16))

        #   Same for a plain assignment; here the LHS is not const-qualified
        fasm = PsAssignment(PsExpression.make(ctx.get_symbol("z")), freeze(3 - w))
        fasm = typify(fasm)

        assert ctx.get_symbol("z").dtype == Fp(16)
        assert fasm.lhs.dtype == Fp(16)

        #   A relation on the RHS makes the declared symbol a bool
        fasm = PsDeclaration(
            PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q))
        )
        fasm = typify(fasm)

        assert ctx.get_symbol("r").dtype == Bool()
        assert fasm.lhs.dtype == constify(Bool())
        assert fasm.rhs.dtype == constify(Bool())
    
    
    def test_erronous_typing():
        """Conflicting operand or assignment types must raise a ``TypificationError``."""
        ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        x, y, z = sp.symbols("x, y, z")
        q = TypedSymbol("q", np.float32)
        w = TypedSymbol("w", np.float16)

        #   Untyped (default float64) symbols mixed with a float32 symbol
        mixed = freeze(2 * x + 3 * y + q - 4)
        with pytest.raises(TypificationError):
            typify(mixed)

        faulty_assignments = [
            #   Conflict between LHS and RHS symbols
            Assignment(q, 3 - w),
            #   Do not propagate types back from LHS symbols to RHS symbols
            Assignment(q, 3 - x),
            #   Augmented assignment to an untyped symbol with a float32 RHS
            AddAugmentedAssignment(z, 3 - q),
        ]
        for assignment in faulty_assignments:
            frozen = freeze(assignment)
            with pytest.raises(TypificationError):
                typify(frozen)
    
    
    def test_invalid_indices():
        """Symbols of a floating-point default type cannot be used as array indices."""
        ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
        typify = Typifier(ctx)

        arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64))))
        x, y, z = [PsExpression.make(ctx.get_symbol(name)) for name in "xyz"]

        #   Using default-typed symbols as array indices is illegal when the default type is a float

        #   ... when the subscript is written to ...
        write_indexed = PsAssignment(PsSubscript(arr, x + y), z)
        with pytest.raises(TypificationError):
            typify(write_indexed)

        #   ... and when it is read from
        read_indexed = PsAssignment(z, PsSubscript(arr, x + y))
        with pytest.raises(TypificationError):
            typify(read_indexed)
    
    
    def test_typify_integer_binops():
        ctx = KernelCreationContext()
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)
    
        ctx.get_symbol("x", ctx.index_dtype)
        ctx.get_symbol("y", ctx.index_dtype)
    
        x, y = sp.symbols("x, y")
        expr = bit_shift_left(
            bit_shift_right(bitwise_and(2, 2), bitwise_or(x, y)), bitwise_xor(2, 2)
        )
        expr = freeze(expr)
        expr = typify(expr)
    
        def check(expr):
            match expr:
                case PsConstantExpr(cs):
                    assert cs.value == 2
                    assert cs.dtype == constify(ctx.index_dtype)
                case PsSymbolExpr(symb):
                    assert symb.name in "xyz"
                    assert symb.dtype == ctx.index_dtype
                case PsBinOp(op1, op2):
                    check(op1)
                    check(op2)
                case _:
                    pytest.fail(f"Unexpected expression: {expr}")
    
        check(expr)
    
    
    def test_typify_integer_binops_floating_arg():
        """An integer-only shift rejects an untyped (default-typed) operand."""
        ctx = KernelCreationContext()
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        shifted = freeze(bit_shift_left(sp.Symbol("x"), 2))

        with pytest.raises(TypificationError):
            typify(shifted)
    
    
    def test_typify_integer_binops_in_floating_context():
        """An index-typed shift cannot be added to an untyped (default-typed) symbol."""
        ctx = KernelCreationContext()
        freeze = FreezeExpressions(ctx)
        typify = Typifier(ctx)

        #   `i` is registered with the index type; `x` stays untyped
        ctx.get_symbol("i", ctx.index_dtype)

        x, i = sp.symbols("x, i")
        mixed = freeze(x + bit_shift_left(i, 2))

        with pytest.raises(TypificationError):
            typify(mixed)
    
    
    def test_typify_constant_clones():
        """Typifying an expression must leave a clone taken beforehand untyped."""
        ctx = KernelCreationContext(default_dtype=Fp(32))
        typify = Typifier(ctx)

        three = PsConstantExpr(PsConstant(3.0))
        x = PsSymbolExpr(ctx.get_symbol("x"))
        original_expr = three + x
        untouched = original_expr.clone()

        typify(original_expr)

        cloned_constant = cast(PsConstantExpr, untouched.operand1)
        assert cloned_constant.dtype is None
        assert cloned_constant.constant.dtype is None
    
    
    def test_typify_bools_and_relations():
        """Logical connectives over relations and bool symbols yield a const bool."""
        ctx = KernelCreationContext()
        typify = Typifier(ctx)

        true = PsConstantExpr(PsConstant(True, Bool()))
        p, q = (PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq")
        x, y = (PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy")

        #   (x == y) && (true && !(p || q))
        negated_disjunction = PsNot(PsOr(p, q))
        condition = PsAnd(PsEq(x, y), PsAnd(true, negated_disjunction))

        assert typify(condition).dtype == Bool(const=True)
    
    
    def test_bool_in_numerical_context():
        """Booleans must not appear as operands of arithmetic operations."""
        ctx = KernelCreationContext()
        typify = Typifier(ctx)

        true = PsConstantExpr(PsConstant(True, Bool()))
        p, q = (PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq")

        arithmetic_on_bools = true + (p - q)
        with pytest.raises(TypificationError):
            typify(arithmetic_on_bools)
    
    
    @pytest.mark.parametrize("rel", [PsEq, PsNe, PsLt, PsGt, PsLe, PsGe])
    def test_typify_conditionals(rel):
        """A conditional whose condition is any relation gets a const bool condition."""
        ctx = KernelCreationContext()
        typify = Typifier(ctx)

        x, y = (PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy")

        typed = typify(PsConditional(rel(x, y), PsBlock([])))
        assert typed.condition.dtype == Bool(const=True)
    
    
    def test_invalid_conditions():
        """Non-bool conditions and non-bool operands of logical operators are rejected."""
        ctx = KernelCreationContext()
        typify = Typifier(ctx)

        x, y = (PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy")
        p, q = (PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq")

        #   Arithmetic (float) expression used as the condition
        bad_conditional = PsConditional(x + y, PsBlock([]))
        with pytest.raises(TypificationError):
            typify(bad_conditional)

        #   Float operand nested inside a logical expression
        bad_conditional = PsConditional(PsAnd(p, PsOr(x, q)), PsBlock([]))
        with pytest.raises(TypificationError):
            typify(bad_conditional)
    
    
    def test_typify_ternary():
        """A ternary needs a bool condition and matching branch types.

        A valid ternary is typed with the (const) type of its branches.
        """
        ctx = KernelCreationContext()
        typify = Typifier(ctx)

        x, y = (PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy")
        a, b = (PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "ab")
        p, q = (PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq")

        #   Valid: float branches, bool symbol as condition
        float_select = typify(PsTernary(p, x, y))
        assert float_select.dtype == Fp(32, const=True)

        #   Valid: int branches, logical expression as condition
        int_select = typify(PsTernary(PsAnd(p, q), a, b + a))
        assert int_select.dtype == Int(32, const=True)

        #   Invalid: int and float branches
        mixed_branches = PsTernary(PsAnd(p, q), a, x)
        with pytest.raises(TypificationError):
            typify(mixed_branches)

        #   Invalid: float condition
        float_condition = PsTernary(y, a, b)
        with pytest.raises(TypificationError):
            typify(float_condition)
    
    
    def test_cfunction():
        """Calls to a ``CFunction`` parsed from an annotated Python stub are typed from its signature.

        The call takes the (const) declared return type, and each argument takes the
        (const) declared parameter type. Passing an ``Int(32)`` argument where an
        ``np.float32`` parameter is declared raises a ``TypificationError``.
        """
        ctx = KernelCreationContext()
        typify = Typifier(ctx)
        x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]
        p, q = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "pq"]

        #   Only the annotated signature of this stub is used by `CFunction.parse`;
        #   the `assert False` body presumably guards against it ever being called
        def _threeway(x: np.float32, y: np.float32) -> np.int32:
            assert False

        threeway = CFunction.parse(_threeway)

        result = typify(PsCall(threeway, [x, y]))

        assert result.get_dtype() == Int(32, const=True)
        assert result.args[0].get_dtype() == Fp(32, const=True)
        assert result.args[1].get_dtype() == Fp(32, const=True)

        #   Second argument is Int(32) but the parameter is declared np.float32
        with pytest.raises(TypificationError):
            _ = typify(PsCall(threeway, (x, p)))