Skip to content
Snippets Groups Projects
test_config.py 1.96 KiB
import pytest

from dataclasses import dataclass
import numpy as np
from pystencils.codegen.config import BasicOption, Option, Category, ConfigBase, CreateKernelConfig
from pystencils.types.quick import Int, UInt


def test_descriptors():
    """Exercise the option/category descriptor machinery of ``ConfigBase``.

    Covers unset vs. default values, setting and unsetting options,
    validator-based conversion (and rejection) of option arguments,
    and copying of category objects on assignment.
    """

    @dataclass
    class SampleCategory(ConfigBase):
        val1: BasicOption[int] = BasicOption(2)
        val2: Option[bool, str] = Option(False)

        @val2.validate
        def validate_val2(self, v: str):
            # Map common textual spellings of a boolean onto ``bool``;
            # the comparison is case-insensitive.
            normalized = v.lower()
            if normalized in ("off", "false", "no"):
                return False
            elif normalized in ("on", "true", "yes"):
                return True

            raise ValueError(f"Cannot interpret {v!r} as a boolean")

    @dataclass
    class SampleConfig(ConfigBase):
        cat: Category[SampleCategory] = Category(SampleCategory())
        val: BasicOption[str] = BasicOption("fallback")

    cfg = SampleConfig()

    #   Check unset and default values
    assert cfg.val is None
    assert not cfg.is_option_set("val")
    assert cfg.get_option("val") == "fallback"

    #   Check setting
    cfg.val = "test"
    assert cfg.val == "test"
    assert cfg.get_option("val") == "test"
    assert cfg.is_option_set("val")

    #   Check unsetting: the option reverts to its default
    cfg.val = None
    assert not cfg.is_option_set("val")
    assert cfg.val is None
    assert cfg.get_option("val") == "fallback"

    #   Check category
    assert cfg.cat.val1 is None
    assert cfg.cat.get_option("val1") == 2
    assert cfg.cat.val2 is None
    assert cfg.cat.get_option("val2") is False

    #   Check copy on category setting
    c = SampleCategory(32, "on")
    cfg.cat = c
    assert cfg.cat.val1 == 32
    assert cfg.cat.val2 is True

    #   The config must hold its own copy, decoupled from ``c``
    assert cfg.cat is not c
    c.val1 = 13
    assert cfg.cat.val1 == 32

    #   Check validator conversion (case-insensitive) and rejection
    cfg.cat.val2 = "OFF"
    assert cfg.cat.val2 is False

    with pytest.raises(ValueError):
        cfg.cat.val2 = "maybe"


def test_config_validation():
    """Check that ``CreateKernelConfig.index_dtype`` converts and validates its input.

    Strings and NumPy scalar types naming integer types are accepted and
    normalized to pystencils integer types; floating-point types are refused,
    both when passed to the constructor and when assigned afterwards.
    """
    config = CreateKernelConfig(index_dtype="int32")
    assert config.index_dtype == Int(32)

    # Assignment after construction is converted as well.
    config.index_dtype = np.uint64
    assert config.index_dtype == UInt(64)

    # A floating-point type must be refused by the constructor ...
    with pytest.raises(ValueError):
        CreateKernelConfig(index_dtype=np.float32)

    # ... and on assignment to an existing config.
    with pytest.raises(ValueError):
        config.index_dtype = "double"