import pytest
import sympy as sp
import numpy as np

from typing import cast

from pystencils import Assignment, TypedSymbol, Field, FieldType

from pystencils.backend.ast.structural import PsDeclaration
from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp
from pystencils.backend.constants import PsConstant
from pystencils.types import constify
from pystencils.types.quick import Fp, create_numeric_type
from pystencils.backend.kernelcreation.context import KernelCreationContext
from pystencils.backend.kernelcreation.freeze import FreezeExpressions
from pystencils.backend.kernelcreation.typification import Typifier, TypificationError

from pystencils.sympyextensions.integer_functions import (
    bit_shift_left,
    bit_shift_right,
    bitwise_and,
    bitwise_xor,
    bitwise_or,
)


def test_typify_simple():
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    asm = Assignment(z, 2 * x + y)

    fasm = freeze(asm)
    fasm = typify(fasm)

    assert isinstance(fasm, PsDeclaration)

    def check(expr):
        assert expr.dtype == ctx.default_dtype
        match expr:
            case PsConstantExpr(cs):
                assert cs.value == 2
                assert cs.dtype == constify(ctx.default_dtype)
            case PsSymbolExpr(symb):
                assert symb.name in "xyz"
                assert symb.dtype == ctx.default_dtype
            case PsBinOp(op1, op2):
                check(op1)
                check(op2)
            case _:
                pytest.fail(f"Unexpected expression: {expr}")

    check(fasm.lhs)
    check(fasm.rhs)


def test_typify_structs():
    """Check typification of struct-member accesses in an ``Fp(32)`` context.

    Assigning the ``float32`` member ``data`` to the symbol ``x`` is expected
    to typify without error; assigning the ``uint32`` member ``size`` to the
    same symbol afterwards is expected to raise a ``TypificationError``.
    """
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    struct_dtype = np.dtype([("size", np.uint32), ("data", np.float32)])
    field = Field.create_generic(
        "f", 1, dtype=struct_dtype, field_type=FieldType.CUSTOM
    )
    target = sp.Symbol("x")

    #   Good: the float member must pass through the typifier unharmed
    typify(freeze(Assignment(target, field.absolute_access((0,), "data"))))

    #   Bad: the integer member must be rejected
    frozen_bad = freeze(Assignment(target, field.absolute_access((0,), "size")))
    with pytest.raises(TypificationError):
        typify(frozen_bad)


def test_contextual_typing():
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    expr = freeze(2 * x + 3 * y + z - 4)
    expr = typify(expr)

    def check(expr):
        assert expr.dtype == ctx.default_dtype
        match expr:
            case PsConstantExpr(cs):
                assert cs.value in (2, 3, -4)
                assert cs.dtype == constify(ctx.default_dtype)
            case PsSymbolExpr(symb):
                assert symb.name in "xyz"
                assert symb.dtype == ctx.default_dtype
            case PsBinOp(op1, op2):
                check(op1)
                check(op2)
            case _:
                pytest.fail(f"Unexpected expression: {expr}")

    check(expr)


def test_erronous_typing():
    """Expressions mixing conflicting data types must fail typification.

    The context defaults to ``float64``; ``q`` is declared ``float32`` and
    ``w`` ``float16``. Each of the three inputs below combines operands in a
    way the typifier is expected to reject with a ``TypificationError``.
    """
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, _ = sp.symbols("x, y, z")
    q = TypedSymbol("q", np.float32)
    w = TypedSymbol("w", np.float16)

    def expect_rejection(sympy_obj):
        #   Freeze outside the `raises` block so only typification may fail
        frozen = freeze(sympy_obj)
        with pytest.raises(TypificationError):
            typify(frozen)

    #   Typed symbol `q` inside an expression of otherwise untyped symbols
    expect_rejection(2 * x + 3 * y + q - 4)

    #   Two differently typed symbols on either side of an assignment
    expect_rejection(Assignment(q, 3 - w))

    #   Typed left-hand side, untyped symbol on the right-hand side
    expect_rejection(Assignment(q, 3 - x))


def test_typify_integer_binops():
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    ctx.get_symbol("x", ctx.index_dtype)
    ctx.get_symbol("y", ctx.index_dtype)

    x, y = sp.symbols("x, y")
    expr = bit_shift_left(
        bit_shift_right(bitwise_and(2, 2), bitwise_or(x, y)), bitwise_xor(2, 2)
    )
    expr = freeze(expr)
    expr = typify(expr)

    def check(expr):
        match expr:
            case PsConstantExpr(cs):
                assert cs.value == 2
                assert cs.dtype == constify(ctx.index_dtype)
            case PsSymbolExpr(symb):
                assert symb.name in "xyz"
                assert symb.dtype == ctx.index_dtype
            case PsBinOp(op1, op2):
                check(op1)
                check(op2)
            case _:
                pytest.fail(f"Unexpected expression: {expr}")

    check(expr)


def test_typify_integer_binops_floating_arg():
    """A left shift of a symbol not registered as an integer must be rejected.

    ``x`` is not given a type here, so the typifier is expected to raise a
    ``TypificationError`` for ``bit_shift_left(x, 2)``.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    shifted = freeze(bit_shift_left(sp.Symbol("x"), 2))

    with pytest.raises(TypificationError):
        typify(shifted)


def test_typify_integer_binops_in_floating_context():
    """An integer shift added to an untyped symbol must be rejected.

    ``i`` is registered with the context's index data type, ``x`` is left
    untyped; the typifier is expected to raise a ``TypificationError`` for
    ``x + bit_shift_left(i, 2)``.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    #   Make `i` an index-typed symbol before the expression is frozen
    ctx.get_symbol("i", ctx.index_dtype)

    x, i = sp.symbols("x, i")
    mixed = freeze(x + bit_shift_left(i, 2))

    with pytest.raises(TypificationError):
        typify(mixed)


def test_typify_constant_clones():
    """Typifying an expression must not affect a clone taken beforehand.

    After typifying ``3.0 + x``, the first operand of the earlier clone and
    the constant it wraps must both still have ``dtype is None``.
    """
    ctx = KernelCreationContext(default_dtype=Fp(32))
    typify = Typifier(ctx)

    constant = PsConstantExpr(PsConstant(3.0))
    symbol = PsSymbolExpr(ctx.get_symbol("x"))
    original = constant + symbol
    untouched = original.clone()

    typify(original)

    cloned_operand = untouched.operand1
    assert cloned_operand.dtype is None
    assert cast(PsConstantExpr, cloned_operand).constant.dtype is None