Skip to content
Snippets Groups Projects
test_typification.py 18.66 KiB
import pytest
import sympy as sp
import numpy as np

from typing import cast

from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment
from pystencils.sympyextensions.pointers import mem_acc

from pystencils.backend.ast.structural import (
    PsDeclaration,
    PsAssignment,
    PsExpression,
    PsConditional,
    PsBlock,
)
from pystencils.backend.ast.expressions import (
    PsArrayInitList,
    PsCast,
    PsConstantExpr,
    PsSymbolExpr,
    PsSubscript,
    PsBinOp,
    PsAnd,
    PsOr,
    PsNot,
    PsEq,
    PsNe,
    PsGe,
    PsLe,
    PsGt,
    PsLt,
    PsCall,
    PsTernary,
    PsMemAcc
)
from pystencils.backend.ast.vector import PsVecBroadcast
from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import CFunction
from pystencils.types import constify, create_type, create_numeric_type, PsVectorType
from pystencils.types.quick import Fp, Int, Bool, Arr, Ptr
from pystencils.backend.kernelcreation.context import KernelCreationContext
from pystencils.backend.kernelcreation.freeze import FreezeExpressions
from pystencils.backend.kernelcreation.typification import Typifier, TypificationError

from pystencils.sympyextensions.integer_functions import (
    bit_shift_left,
    bit_shift_right,
    bitwise_and,
    bitwise_xor,
    bitwise_or,
)


def test_typify_simple():
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    asm = Assignment(z, 2 * x + y)

    fasm = freeze(asm)
    fasm = typify(fasm)

    assert isinstance(fasm, PsDeclaration)

    def check(expr):
        assert expr.dtype == ctx.default_dtype
        match expr:
            case PsConstantExpr(cs):
                assert cs.value == 2
                assert cs.dtype == constify(ctx.default_dtype)
            case PsSymbolExpr(symb):
                assert symb.name in "xyz"
                assert symb.dtype == ctx.default_dtype
            case PsBinOp(op1, op2):
                check(op1)
                check(op2)
            case _:
                pytest.fail(f"Unexpected expression: {expr}")

    check(fasm.lhs)
    check(fasm.rhs)


def test_lhs_constness():
    default_type = Fp(32)
    ctx = KernelCreationContext(default_dtype=default_type)
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    f = Field.create_generic(
        "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
    )
    f_const = Field.create_generic(
        "f_const",
        1,
        index_shape=(1,),
        dtype=constify(default_type),
        field_type=FieldType.CUSTOM,
    )

    x, y, z = sp.symbols("x, y, z")

    #   Can assign to non-const LHS
    asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y)))
    assert not asm.lhs.get_dtype().const

    #   Cannot assign to const left-hand side
    with pytest.raises(TypificationError):
        _ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y)))

    np_struct = np.dtype([("size", np.uint32), ("data", np.float32)], align=True)
    struct_type = constify(create_type(np_struct))
    struct_field = Field.create_generic(
        "struct_field", 1, dtype=struct_type, field_type=FieldType.CUSTOM
    )

    with pytest.raises(TypificationError):
        _ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x)))

    #   Const LHS is only OK in declarations

    q = ctx.get_symbol("q", Fp(32, const=True))
    ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q))
    ast = typify(ast)
    assert ast.lhs.dtype == Fp(32)

    ast = PsAssignment(PsExpression.make(q), PsExpression.make(q))
    with pytest.raises(TypificationError):
        typify(ast)


def test_typify_structs():
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    np_struct = np.dtype([("size", np.uint32), ("data", np.float32)], align=True)
    f = Field.create_generic("f", 1, dtype=np_struct, field_type=FieldType.CUSTOM)
    x = sp.Symbol("x")

    #   Good
    asm = Assignment(x, f.absolute_access((0,), "data"))
    fasm = freeze(asm)
    fasm = typify(fasm)

    asm = Assignment(f.absolute_access((0,), "data"), x)
    fasm = freeze(asm)
    fasm = typify(fasm)

    #   Bad
    asm = Assignment(x, f.absolute_access((0,), "size"))
    fasm = freeze(asm)
    with pytest.raises(TypificationError):
        fasm = typify(fasm)


def test_default_typing():
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    expr = freeze(2 * x + 3 * y + z - 4)
    expr = typify(expr)

    def check(expr):
        assert expr.dtype == ctx.default_dtype
        match expr:
            case PsConstantExpr(cs):
                assert cs.value in (2, 3, -4)
                assert cs.dtype == constify(ctx.default_dtype)
            case PsSymbolExpr(symb):
                assert symb.name in "xyz"
                assert symb.dtype == ctx.default_dtype
            case PsBinOp(op1, op2):
                check(op1)
                check(op2)
            case _:
                pytest.fail(f"Unexpected expression: {expr}")

    check(expr)


def test_inline_arrays_1d():
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x = sp.Symbol("x")
    y = TypedSymbol("y", Fp(16))
    idx = TypedSymbol("idx", Int(32))

    arr: PsArrayInitList = cast(PsArrayInitList, freeze(sp.Tuple(1, 2, 3, 4)))
    decl = PsDeclaration(freeze(x), freeze(y) + PsSubscript(arr, (freeze(idx),)))
    #   The array elements should learn their type from the context, which gets it from `y`

    decl = typify(decl)
    assert decl.lhs.dtype == Fp(16)
    assert decl.rhs.dtype == Fp(16)

    assert arr.dtype == Arr(Fp(16), (4,))
    for item in arr.items:
        assert item.dtype == Fp(16)


def test_inline_arrays_3d():
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x = sp.Symbol("x")
    y = TypedSymbol("y", Fp(16))
    idx = [TypedSymbol(f"idx_{i}", Int(32)) for i in range(3)]

    arr: PsArrayInitList = freeze(
        sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10)))
    )
    decl = PsDeclaration(
        freeze(x),
        freeze(y) + PsSubscript(arr, (freeze(idx[0]), freeze(idx[1]), freeze(idx[2]))),
    )
    #   The array elements should learn their type from the context, which gets it from `y`

    decl = typify(decl)
    assert decl.lhs.dtype == Fp(16)
    assert decl.rhs.dtype == Fp(16)

    assert arr.dtype == Arr(Fp(16), (2, 3, 2))
    assert arr.shape == (2, 3, 2)
    for item in arr.items:
        assert item.dtype == Fp(16)


def test_array_subscript():
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    arr = sp.IndexedBase(TypedSymbol("arr", Arr(Fp(32), (16,))))
    expr = freeze(arr[3])
    expr = typify(expr)

    assert expr.dtype == Fp(32)

    arr = sp.IndexedBase(TypedSymbol("arr2", Arr(Fp(32), (7, 31))))
    expr = freeze(arr[3, 5])
    expr = typify(expr)

    assert expr.dtype == Fp(32)


def test_invalid_subscript():
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    non_arr = sp.IndexedBase(TypedSymbol("non_arr", Int(64)))
    expr = freeze(non_arr[3])

    with pytest.raises(TypificationError):
        expr = typify(expr)

    wrong_shape_arr = sp.IndexedBase(
        TypedSymbol("wrong_shape_arr", Arr(Fp(32), (7, 31, 5)))
    )
    expr = freeze(wrong_shape_arr[3, 5])

    with pytest.raises(TypificationError):
        expr = typify(expr)

    #   raw pointers are not arrays, cannot enter subscript
    ptr = sp.IndexedBase(
        TypedSymbol("ptr", Ptr(Int(16)))
    )
    expr = freeze(ptr[37])
    
    with pytest.raises(TypificationError):
        expr = typify(expr)

    
def test_mem_acc():
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    ptr = TypedSymbol("ptr", Ptr(Int(64)))
    idx = TypedSymbol("idx", Int(32))
    
    expr = freeze(mem_acc(ptr, idx))
    expr = typify(expr)

    assert isinstance(expr, PsMemAcc)
    assert expr.dtype == Int(64)
    assert expr.offset.dtype == Int(32)


def test_invalid_mem_acc():
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    non_ptr = TypedSymbol("non_ptr", Int(64))
    idx = TypedSymbol("idx", Int(32))
    
    expr = freeze(mem_acc(non_ptr, idx))
    
    with pytest.raises(TypificationError):
        _ = typify(expr)
    
    arr = TypedSymbol("arr", Arr(Int(64), (31,)))
    idx = TypedSymbol("idx", Int(32))
    
    expr = freeze(mem_acc(arr, idx))
    
    with pytest.raises(TypificationError):
        _ = typify(expr)


def test_lhs_inference():
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    q = TypedSymbol("q", np.float32)
    w = TypedSymbol("w", np.float16)

    #   Type of the LHS is propagated to untyped RHS symbols

    asm = Assignment(x, 3 - q)
    fasm = typify(freeze(asm))

    assert ctx.get_symbol("x").dtype == Fp(32)
    assert fasm.lhs.dtype == Fp(32)

    asm = Assignment(y, 3 - w)
    fasm = typify(freeze(asm))

    assert ctx.get_symbol("y").dtype == Fp(16)
    assert fasm.lhs.dtype == Fp(16)

    fasm = PsAssignment(PsExpression.make(ctx.get_symbol("z")), freeze(3 - w))
    fasm = typify(fasm)

    assert ctx.get_symbol("z").dtype == Fp(16)
    assert fasm.lhs.dtype == Fp(16)

    fasm = PsDeclaration(
        PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q))
    )
    fasm = typify(fasm)

    assert ctx.get_symbol("r").dtype == Bool()
    assert fasm.lhs.dtype == Bool()
    assert fasm.rhs.dtype == Bool()


def test_array_declarations():
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    
    #   Array type fallback to default
    arr1 = sp.Symbol("arr1")
    decl = freeze(Assignment(arr1, sp.Tuple(1, 2, 3, 4)))
    decl = typify(decl)

    assert ctx.get_symbol("arr1").dtype == Arr(Fp(32), (4,))
    assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (4,))

    #   Array type determined by default-typed symbol
    arr2 = sp.Symbol("arr2")
    decl = freeze(Assignment(arr2, sp.Tuple((x, y, -7), (3, -2, 51))))
    decl = typify(decl)

    assert ctx.get_symbol("arr2").dtype == Arr(Fp(32), (2, 3))
    assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (2, 3))

    #   Array type determined by pre-typed symbol
    q = TypedSymbol("q", Fp(16))
    arr3 = sp.Symbol("arr3")
    decl = freeze(Assignment(arr3, sp.Tuple((q, 2), (-q, 0.123))))
    decl = typify(decl)

    assert ctx.get_symbol("arr3").dtype == Arr(Fp(16), (2, 2))
    assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(16), (2, 2))

    #   Array type determined by LHS symbol
    arr4 = TypedSymbol("arr4", Arr(Int(16), 4))
    decl = freeze(Assignment(arr4, sp.Tuple(11, 1, 4, 2)))
    decl = typify(decl)

    assert decl.lhs.dtype == decl.rhs.dtype == Arr(Int(16), 4)


def test_erronous_typing():
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    q = TypedSymbol("q", np.float32)
    w = TypedSymbol("w", np.float16)

    expr = freeze(2 * x + 3 * y + q - 4)

    with pytest.raises(TypificationError):
        typify(expr)

    #   Conflict between LHS and RHS symbols
    asm = Assignment(q, 3 - w)
    fasm = freeze(asm)
    with pytest.raises(TypificationError):
        typify(fasm)

    #   Do not propagate types back from LHS symbols to RHS symbols
    asm = Assignment(q, 3 - x)
    fasm = freeze(asm)
    with pytest.raises(TypificationError):
        typify(fasm)

    asm = AddAugmentedAssignment(z, 3 - q)
    fasm = freeze(asm)
    with pytest.raises(TypificationError):
        typify(fasm)


def test_invalid_indices():
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    typify = Typifier(ctx)

    arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64), (61,))))
    x, y, z = [PsExpression.make(ctx.get_symbol(x)) for x in "xyz"]

    #   Using default-typed symbols as array indices is illegal when the default type is a float

    fasm = PsAssignment(PsSubscript(arr, (x + y,)), z)

    with pytest.raises(TypificationError):
        typify(fasm)

    fasm = PsAssignment(z, PsSubscript(arr, (x + y,)))

    with pytest.raises(TypificationError):
        typify(fasm)


def test_typify_integer_binops():
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    ctx.get_symbol("x", ctx.index_dtype)
    ctx.get_symbol("y", ctx.index_dtype)

    x, y = sp.symbols("x, y")
    expr = bit_shift_left(
        bit_shift_right(bitwise_and(2, 2), bitwise_or(x, y)), bitwise_xor(2, 2)
    )
    expr = freeze(expr)
    expr = typify(expr)

    def check(expr):
        match expr:
            case PsConstantExpr(cs):
                assert cs.value == 2
                assert cs.dtype == constify(ctx.index_dtype)
            case PsSymbolExpr(symb):
                assert symb.name in "xyz"
                assert symb.dtype == ctx.index_dtype
            case PsBinOp(op1, op2):
                check(op1)
                check(op2)
            case _:
                pytest.fail(f"Unexpected expression: {expr}")

    check(expr)


def test_typify_integer_binops_floating_arg():
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x = sp.Symbol("x")
    expr = bit_shift_left(x, 2)
    expr = freeze(expr)

    with pytest.raises(TypificationError):
        expr = typify(expr)


def test_typify_integer_binops_in_floating_context():
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    ctx.get_symbol("i", ctx.index_dtype)

    x, i = sp.symbols("x, i")
    expr = x + bit_shift_left(i, 2)
    expr = freeze(expr)

    with pytest.raises(TypificationError):
        expr = typify(expr)


def test_typify_constant_clones():
    ctx = KernelCreationContext(default_dtype=Fp(32))
    typify = Typifier(ctx)

    c = PsConstantExpr(PsConstant(3.0))
    x = PsSymbolExpr(ctx.get_symbol("x"))
    expr = c + x
    expr_clone = expr.clone()

    expr = typify(expr)

    assert expr_clone.operand1.dtype is None
    assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None


def test_typify_bools_and_relations():
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    true = PsConstantExpr(PsConstant(True, Bool()))
    p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"]
    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]

    expr = PsAnd(PsEq(x, y), PsAnd(true, PsNot(PsOr(p, q))))
    expr = typify(expr)

    assert expr.dtype == Bool() 


def test_bool_in_numerical_context():
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    true = PsConstantExpr(PsConstant(True, Bool()))
    p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"]

    expr = true + (p - q)
    with pytest.raises(TypificationError):
        typify(expr)


@pytest.mark.parametrize("rel", [PsEq, PsNe, PsLt, PsGt, PsLe, PsGe])
def test_typify_conditionals(rel):
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]

    cond = PsConditional(rel(x, y), PsBlock([]))
    cond = typify(cond)
    assert cond.condition.dtype == Bool()


def test_invalid_conditions():
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]
    p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"]

    cond = PsConditional(x + y, PsBlock([]))
    with pytest.raises(TypificationError):
        typify(cond)


def test_typify_ternary():
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]
    a, b = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "ab"]
    p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"]

    expr = PsTernary(p, x, y)
    expr = typify(expr)
    assert expr.dtype == Fp(32)

    expr = PsTernary(PsAnd(p, q), a, b + a)
    expr = typify(expr)
    assert expr.dtype == Int(32)

    expr = PsTernary(PsAnd(p, q), a, x)
    with pytest.raises(TypificationError):
        typify(expr)

    expr = PsTernary(y, a, b)
    with pytest.raises(TypificationError):
        typify(expr)


def test_cfunction():
    ctx = KernelCreationContext()
    typify = Typifier(ctx)
    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]
    p, q = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "pq"]

    def _threeway(x: np.float32, y: np.float32) -> np.int32:
        assert False

    threeway = CFunction.parse(_threeway)

    result = typify(PsCall(threeway, [x, y]))

    assert result.get_dtype() == Int(32)
    assert result.args[0].get_dtype() == Fp(32)
    assert result.args[1].get_dtype() == Fp(32)

    with pytest.raises(TypificationError):
        _ = typify(PsCall(threeway, (x, p)))


def test_typify_integer_vectors():
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    a, b, c = [PsExpression.make(ctx.get_symbol(name, PsVectorType(Int(32), 4))) for name in "abc"]
    d, e = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "de"]

    result = typify(a + (b / c) - a * c)
    assert result.get_dtype() == PsVectorType(Int(32), 4)

    result = typify(PsVecBroadcast(4, d - e) - PsVecBroadcast(4, e / d))
    assert result.get_dtype() == PsVectorType(Int(32), 4)


def test_typify_bool_vectors():
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = [PsExpression.make(ctx.get_symbol(name, PsVectorType(Fp(32), 4))) for name in "xy"]
    p, q = [PsExpression.make(ctx.get_symbol(name, PsVectorType(Bool(), 4))) for name in "pq"]

    result = typify(PsAnd(PsOr(p, q), p))
    assert result.get_dtype() == PsVectorType(Bool(), 4)

    result = typify(PsAnd(PsLt(x, y), PsGe(y, x)))
    assert result.get_dtype() == PsVectorType(Bool(), 4)


def test_inference_fails():
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x = PsExpression.make(PsConstant(42))

    with pytest.raises(TypificationError):
        typify(PsEq(x, x))

    with pytest.raises(TypificationError):
        typify(PsArrayInitList([x]))

    with pytest.raises(TypificationError):
        typify(PsCast(ctx.default_dtype, x))