# test_constants.py -- unit tests for PsConstant (pystencils.backend.constants)
import numpy as np
import pytest

from pystencils.types import PsTypeError
from pystencils.backend.constants import PsConstant
from pystencils.types.quick import Fp, Bool, UInt, SInt
from pystencils.backend.exceptions import PsInternalCompilerError


def test_constant_equality():
    """Check equality and hashing of constants across equal and differing dtypes."""
    single_a = PsConstant(1.0, Fp(32))
    single_b = PsConstant(1.0, Fp(32))

    #  Same value and same dtype: equal, with matching hashes
    assert single_a == single_b
    assert hash(single_a) == hash(single_b)

    #  Same value but a different dtype: unequal, with differing hashes
    double = PsConstant(1.0, Fp(64))
    assert single_a != double
    assert hash(single_a) != hash(double)

    #  After reinterpret_as(Fp(64)) the result matches the Fp(64) constant,
    #  not the Fp(32) constant it was derived from
    reinterpreted = single_a.reinterpret_as(Fp(64))
    assert reinterpreted != single_a
    assert reinterpreted == double


def test_interpret():
    """Check `interpret_as` on a constant built without a dtype and on one built with a dtype."""
    typed = PsConstant(3.4, Fp(32))
    untyped = PsConstant(3.4)

    #  Interpreting the dtype-less constant yields the same constant
    #  as constructing it with that dtype directly
    assert untyped.interpret_as(Fp(32)) == typed

    #  Calling interpret_as on a constant that already carries a dtype
    #  is expected to raise an internal compiler error
    with pytest.raises(PsInternalCompilerError):
        typed.interpret_as(Fp(64))


def test_boolean_constants():
    """Check which Python and NumPy values are accepted as boolean constants.

    Each listed spelling of true/false must produce a constant equal to the one
    built from the plain Python bool; the value 1.1 must be rejected with
    `PsTypeError`.
    """
    equivalents = (
        (True, (1, 1.0, True, np.True_)),
        (False, (0, 0.0, False, np.False_)),
    )

    for flag, spellings in equivalents:
        expected = PsConstant(flag, Bool())
        for spelling in spellings:
            assert PsConstant(spelling, Bool()) == expected

    #  1.1 matches neither the true nor the false spellings above
    with pytest.raises(PsTypeError):
        PsConstant(1.1, Bool())


def test_integer_bounds():
    """Check range validation of integer constants for `UInt(8)` and `SInt(8)`.

    Values in the `accepted` table must construct without error; values in the
    `rejected` table must raise `PsTypeError`. Both Python ints and NumPy
    integer scalars of various widths are exercised.
    """
    #  Each entry pairs a type factory with the values to try against it.
    #  The type itself is instantiated freshly for every constructed constant.
    accepted = (
        (UInt, (255, np.uint8(255), np.int16(255), np.int64(255))),
        (SInt, (-128, np.int16(-128), np.int64(-128))),
    )
    rejected = (
        (UInt, (256, np.int16(256), np.int64(256))),
        (UInt, (-42, np.int32(-42))),
        (SInt, (-129, np.int16(-129), np.int64(-129))),
    )

    #  should not throw:
    for int_type, in_range in accepted:
        for value in in_range:
            PsConstant(value, int_type(8))

    #  should throw:
    for int_type, out_of_range in rejected:
        for value in out_of_range:
            with pytest.raises(PsTypeError):
                PsConstant(value, int_type(8))


def test_floating_bounds():
    """Check range validation of floating-point constants.

    The first pair of values must be accepted by `Fp(16)`, `Fp(32)` and
    `Fp(64)`; the second pair must be rejected by `Fp(16)` with `PsTypeError`
    (presumably because they exceed the half-precision range -- confirm
    against the `PsConstant` implementation).
    """
    for representable in (5.1e4, -5.9e4):
        for width in (16, 32, 64):
            PsConstant(representable, Fp(width))

    for too_large in (8.1e5, -7.6e5):
        with pytest.raises(PsTypeError):
            PsConstant(too_large, Fp(16))