# test_typification.py (11.34 KiB)
import pytest
import sympy as sp
import numpy as np

from typing import cast

from pystencils import Assignment, TypedSymbol, Field, FieldType

from pystencils.backend.ast.structural import (
    PsDeclaration,
    PsAssignment,
    PsExpression,
    PsConditional,
    PsBlock,
)
from pystencils.backend.ast.expressions import (
    PsConstantExpr,
    PsSymbolExpr,
    PsBinOp,
    PsAnd,
    PsOr,
    PsNot,
    PsEq,
    PsNe,
    PsGe,
    PsLe,
    PsGt,
    PsLt,
    PsCall,
    PsTernary
)
from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import CFunction
from pystencils.types import constify, create_type, create_numeric_type
from pystencils.types.quick import Fp, Int, Bool
from pystencils.backend.kernelcreation.context import KernelCreationContext
from pystencils.backend.kernelcreation.freeze import FreezeExpressions
from pystencils.backend.kernelcreation.typification import Typifier, TypificationError

from pystencils.sympyextensions.integer_functions import (
    bit_shift_left,
    bit_shift_right,
    bitwise_and,
    bitwise_xor,
    bitwise_or,
)


def test_typify_simple():
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    asm = Assignment(z, 2 * x + y)

    fasm = freeze(asm)
    fasm = typify(fasm)

    assert isinstance(fasm, PsDeclaration)

    def check(expr):
        assert expr.dtype == constify(ctx.default_dtype)
        match expr:
            case PsConstantExpr(cs):
                assert cs.value == 2
                assert cs.dtype == constify(ctx.default_dtype)
            case PsSymbolExpr(symb):
                assert symb.name in "xyz"
                assert symb.dtype == ctx.default_dtype
            case PsBinOp(op1, op2):
                check(op1)
                check(op2)
            case _:
                pytest.fail(f"Unexpected expression: {expr}")

    check(fasm.lhs)
    check(fasm.rhs)


def test_rhs_constness():
    """Right-hand sides of typified assignments must always be const-qualified.

    Two cases are checked: a plain read from a mutable field, and a compound
    expression that mixes a mutable field, a const field and free symbols.
    """
    dtype = Fp(32)
    ctx = KernelCreationContext(default_dtype=dtype)

    freezer = FreezeExpressions(ctx)
    typifier = Typifier(ctx)

    def make_field(name, field_dtype):
        return Field.create_generic(
            name, 1, index_shape=(1,), dtype=field_dtype, field_type=FieldType.CUSTOM
        )

    f = make_field("f", dtype)
    f_const = make_field("f_const", constify(dtype))

    x, y = sp.symbols("x, y")

    def process(assignment):
        return typifier(freezer(assignment))

    #   Reading a mutable field still yields a const RHS
    simple_read = process(Assignment(x, f.absolute_access([0], [0])))
    assert simple_read.rhs.get_dtype().const

    #   Compound RHS mixing mutable and const field accesses
    target = f.absolute_access([0], [0])
    compound_rhs = (
        f.absolute_access([0], [0]) * f_const.absolute_access([0], [0]) * x + y
    )
    compound = process(Assignment(target, compound_rhs))
    assert compound.rhs.get_dtype().const


def test_lhs_constness():
    """Check const-correctness of assignment left-hand sides.

    - Writing to a mutable field yields a non-const LHS.
    - Writing to a const field, or to a member of a const struct field,
      raises ``TypificationError``.
    - A const-typed symbol may be the LHS of a declaration, but not of a
      plain assignment.
    """
    default_type = Fp(32)
    ctx = KernelCreationContext(default_dtype=default_type)
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    f = Field.create_generic(
        "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
    )
    f_const = Field.create_generic(
        "f_const",
        1,
        index_shape=(1,),
        dtype=constify(default_type),
        field_type=FieldType.CUSTOM,
    )

    x, y = sp.symbols("x, y")

    #   The LHS of an assignment to a mutable field must not be const
    asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y)))
    assert not asm.lhs.get_dtype().const

    #   Cannot assign to a const left-hand side
    with pytest.raises(TypificationError):
        _ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y)))

    #   Writing to a member of a const struct field must be rejected as well
    np_struct = np.dtype([("size", np.uint32), ("data", np.float32)])
    struct_type = constify(create_type(np_struct))
    struct_field = Field.create_generic(
        "struct_field", 1, dtype=struct_type, field_type=FieldType.CUSTOM
    )

    with pytest.raises(TypificationError):
        _ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x)))

    #   Const LHS is only OK in declarations
    q = ctx.get_symbol("q", Fp(32, const=True))
    ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q))
    ast = typify(ast)
    assert ast.lhs.dtype == Fp(32, const=True)

    ast = PsAssignment(PsExpression.make(q), PsExpression.make(q))
    with pytest.raises(TypificationError):
        typify(ast)


def test_typify_structs():
    """Typify reads from and writes to members of a struct-typed field.

    With a float32 default type, the float32 ``data`` member can be read into
    and written from an untyped symbol. Reading the uint32 ``size`` member into
    that symbol must raise ``TypificationError``.
    """
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    record = np.dtype([("size", np.uint32), ("data", np.float32)])
    f = Field.create_generic("f", 1, dtype=record, field_type=FieldType.CUSTOM)
    x = sp.Symbol("x")

    def process(assignment):
        return typify(freeze(assignment))

    #   Good: the float32 `data` member matches the default type, both directions
    process(Assignment(x, f.absolute_access((0,), "data")))
    process(Assignment(f.absolute_access((0,), "data"), x))

    #   Bad: the uint32 `size` member does not match the float32 symbol
    frozen = freeze(Assignment(x, f.absolute_access((0,), "size")))
    with pytest.raises(TypificationError):
        typify(frozen)


def test_contextual_typing():
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    expr = freeze(2 * x + 3 * y + z - 4)
    expr = typify(expr)

    def check(expr):
        assert expr.dtype == constify(ctx.default_dtype)
        match expr:
            case PsConstantExpr(cs):
                assert cs.value in (2, 3, -4)
                assert cs.dtype == constify(ctx.default_dtype)
            case PsSymbolExpr(symb):
                assert symb.name in "xyz"
                assert symb.dtype == ctx.default_dtype
            case PsBinOp(op1, op2):
                check(op1)
                check(op2)
            case _:
                pytest.fail(f"Unexpected expression: {expr}")

    check(expr)


def test_erronous_typing():
    """Mixing incompatible floating-point types must raise ``TypificationError``.

    The context's default type is float64, ``q`` is float32 and ``w`` is float16.
    """
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y = sp.symbols("x, y")
    q = TypedSymbol("q", np.float32)
    w = TypedSymbol("w", np.float16)

    #   Default-typed (float64) symbols combined with a float32 symbol
    mixed = freeze(2 * x + 3 * y + q - 4)
    with pytest.raises(TypificationError):
        typify(mixed)

    #   float16 and float64 right-hand sides assigned to a float32 target
    for rhs in (3 - w, 3 - x):
        frozen = freeze(Assignment(q, rhs))
        with pytest.raises(TypificationError):
            typify(frozen)


def test_typify_integer_binops():
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    ctx.get_symbol("x", ctx.index_dtype)
    ctx.get_symbol("y", ctx.index_dtype)

    x, y = sp.symbols("x, y")
    expr = bit_shift_left(
        bit_shift_right(bitwise_and(2, 2), bitwise_or(x, y)), bitwise_xor(2, 2)
    )
    expr = freeze(expr)
    expr = typify(expr)

    def check(expr):
        match expr:
            case PsConstantExpr(cs):
                assert cs.value == 2
                assert cs.dtype == constify(ctx.index_dtype)
            case PsSymbolExpr(symb):
                assert symb.name in "xyz"
                assert symb.dtype == ctx.index_dtype
            case PsBinOp(op1, op2):
                check(op1)
                check(op2)
            case _:
                pytest.fail(f"Unexpected expression: {expr}")

    check(expr)


def test_typify_integer_binops_floating_arg():
    """A bit shift whose operand is a default-typed symbol must be rejected.

    ``x`` is not registered, so it takes the context's default type, which is
    presumably floating-point (as the test name suggests).
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    frozen = freeze(bit_shift_left(sp.Symbol("x"), 2))

    with pytest.raises(TypificationError):
        typify(frozen)


def test_typify_integer_binops_in_floating_context():
    """An index-typed shift result added to a default-typed symbol must fail.

    ``i`` is registered with the index type, while ``x`` takes the context's
    default type.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    ctx.get_symbol("i", ctx.index_dtype)

    x_sym, i_sym = sp.symbols("x, i")
    frozen = freeze(x_sym + bit_shift_left(i_sym, 2))

    with pytest.raises(TypificationError):
        typify(frozen)


def test_typify_constant_clones():
    """Typifying an expression must not leak types into a clone made earlier.

    The clone is taken before typification. Afterwards, both its constant
    operand node and the wrapped ``PsConstant`` must still be untyped.
    """
    ctx = KernelCreationContext(default_dtype=Fp(32))
    typify = Typifier(ctx)

    expr = PsConstantExpr(PsConstant(3.0)) + PsSymbolExpr(ctx.get_symbol("x"))
    snapshot = expr.clone()

    typify(expr)

    untouched = cast(PsConstantExpr, snapshot.operand1)
    assert untouched.dtype is None
    assert untouched.constant.dtype is None


def test_typify_bools_and_relations():
    """Logical operators over booleans and a float comparison give a const bool."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    def symbols_of(names, dtype):
        return [PsExpression.make(ctx.get_symbol(n, dtype)) for n in names]

    true = PsConstantExpr(PsConstant(True, Bool()))
    p, q = symbols_of("pq", Bool())
    x, y = symbols_of("xy", Fp(32))

    equality = PsEq(x, y)
    neither = PsNot(PsOr(p, q))
    typed = typify(PsAnd(equality, PsAnd(true, neither)))

    assert typed.dtype == Bool(const=True)


def test_bool_in_numerical_context():
    """Arithmetic on boolean operands must raise ``TypificationError``."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    true = PsConstantExpr(PsConstant(True, Bool()))
    p, q = (PsExpression.make(ctx.get_symbol(n, Bool())) for n in ("p", "q"))

    arithmetic = true + (p - q)
    with pytest.raises(TypificationError):
        typify(arithmetic)


@pytest.mark.parametrize("rel", [PsEq, PsNe, PsLt, PsGt, PsLe, PsGe])
def test_typify_conditionals(rel):
    """Each relational operator is a valid branch condition of const bool type."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    lhs, rhs = (PsExpression.make(ctx.get_symbol(n, Fp(32))) for n in ("x", "y"))

    branch = PsConditional(rel(lhs, rhs), PsBlock([]))
    typed = typify(branch)

    assert typed.condition.dtype == Bool(const=True)


def test_invalid_conditions():
    """Non-boolean branch conditions must raise ``TypificationError``.

    Two cases are covered: a float-valued condition, and a logical operator
    with a float operand.
    """
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]
    p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"]

    def assert_rejected(condition):
        branch = PsConditional(condition, PsBlock([]))
        with pytest.raises(TypificationError):
            typify(branch)

    #   Float sum used directly as a condition
    assert_rejected(x + y)
    #   Float operand hidden inside a logical OR
    assert_rejected(PsAnd(p, PsOr(x, q)))

    
def test_typify_ternary():
    """Typification of ``PsTernary`` expressions.

    Both branches must share a type, which becomes the const result type.
    Branches of differing types, or a non-boolean condition, must raise
    ``TypificationError``.
    """
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    def make_symbols(names, dtype):
        return [PsExpression.make(ctx.get_symbol(n, dtype)) for n in names]

    x, y = make_symbols("xy", Fp(32))
    a, b = make_symbols("ab", Int(32))
    p, q = make_symbols("pq", Bool())

    #   Float branches give a const float result
    float_result = typify(PsTernary(p, x, y))
    assert float_result.dtype == Fp(32, const=True)

    #   Integer branches under a compound boolean condition
    int_result = typify(PsTernary(PsAnd(p, q), a, b + a))
    assert int_result.dtype == Int(32, const=True)

    #   Mismatched branch types
    mismatched = PsTernary(PsAnd(p, q), a, x)
    with pytest.raises(TypificationError):
        typify(mismatched)

    #   Float-valued condition
    float_condition = PsTernary(y, a, b)
    with pytest.raises(TypificationError):
        typify(float_condition)


def test_cfunction():
    """Typify calls to a ``CFunction`` parsed from an annotated Python function.

    Arguments take the declared float32 parameter type and the call takes the
    declared int32 return type, both const-qualified. Passing an int32 argument
    where float32 is declared must raise ``TypificationError``.
    """
    ctx = KernelCreationContext()
    typify = Typifier(ctx)
    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]
    p, _q = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "pq"]

    def _threeway(x: np.float32, y: np.float32) -> np.int32:
        #   Presumably only the signature is inspected by CFunction.parse;
        #   this test never calls the body.
        assert False

    threeway = CFunction.parse(_threeway)

    call = typify(PsCall(threeway, [x, y]))

    assert call.get_dtype() == Int(32, const=True)
    for index in (0, 1):
        assert call.args[index].get_dtype() == Fp(32, const=True)

    with pytest.raises(TypificationError):
        _ = typify(PsCall(threeway, (x, p)))