Skip to content
Snippets Groups Projects
Commit a898a3ba authored by Frederik Hennig's avatar Frederik Hennig
Browse files

config descriptors

parent 25a6786a
No related branches found
No related tags found
1 merge request!442Refactor Configuration API
......@@ -2,10 +2,12 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from warnings import warn
from abc import ABC
from collections.abc import Collection
from copy import copy
from typing import Sequence
from dataclasses import dataclass, InitVar, replace
from typing import Sequence, Generic, TypeVar, Callable, Any, cast
from dataclasses import dataclass, InitVar, replace, fields
from .target import Target
from ..field import Field, FieldType
......@@ -28,6 +30,131 @@ class PsOptionsError(Exception):
"""Indicates an option clash in the `CreateKernelConfig`."""
Option_T = TypeVar("Option_T")
Arg_T = TypeVar("Arg_T")
class Option(Generic[Option_T, Arg_T]):
"""Option descriptor.
This descriptor is used to model configuration options.
It maintains a default value for the option that is used when no value
was specified by the user.
In configuration options, the value `None` stands for `unset`.
It can therefore not be used to set an option to the meaning "not any", or "empty"
- for these, special values need to be used.
The Option allows a validator function to be specified,
which will be called to perform sanity checks on user-provided values.
Through the validator, options may also be set from arguments of a different type (`Arg_T`)
than their value type (`Option_T`). If `Arg_T` is different from `Option_T`,
the validator must perform the conversion from the former to the latter.
"""
def __init__(
self,
default: Option_T | None = None,
validator: Callable[[Any, Arg_T | None], Option_T | None] | None = None,
) -> None:
self._default = default
self._validator = validator
self._name: str
self._lookup: str
def validate(self, validator: Callable[[Any, Any], Any] | None):
self._validator = validator
return validator
@property
def default(self) -> Option_T | None:
return self._default
def get(self, obj) -> Option_T | None:
val = getattr(obj, self._lookup, None)
if val is None:
return self._default
else:
return val
def is_set(self, obj) -> bool:
return getattr(obj, self._lookup, None) is not None
def __set_name__(self, owner, name: str):
self._name = name
self._lookup = f"_{name}"
def __get__(self, obj, objtype=None) -> Option_T | None:
if obj is None:
return None
return getattr(obj, self._lookup, None)
def __set__(self, obj, arg: Arg_T | None):
if arg is not None and self._validator is not None:
value = self._validator(obj, arg)
else:
value = cast(Option_T, arg)
setattr(obj, self._lookup, value)
def __delete__(self, obj):
delattr(obj, self._lookup)
class SimpleOption(Option[Option_T, Option_T]):
...
class ConfigBase(ABC):
def get_option(self, name: str) -> Any:
"""Get the value set for the specified option, or the option's default value if none has been set."""
descr: Option = type(self).__dict__[name]
return descr.get(self)
def is_option_set(self, name: str) -> bool:
descr: Option = type(self).__dict__[name]
return descr.is_set(self)
def override(self, other: ConfigBase):
for f in fields(self): # type: ignore
fvalue = getattr(self, f.name)
if isinstance(fvalue, ConfigBase): # type: ignore
fvalue.override(getattr(other, f.name))
else:
new_val = getattr(other, f.name)
if new_val is not None:
setattr(self, f.name, new_val)
Category_T = TypeVar("Category_T", bound=ConfigBase)
class Category(Generic[Category_T]):
"""Descriptor for a category of options.
This descriptor makes sure that when an entire category is set to an object,
that object is copied immediately such that later changes to the original
do not affect this configuration.
"""
def __init__(self, default: Category_T):
self._default = default
def __set_name__(self, owner, name: str):
self._name = name
self._lookup = f"_{name}"
def __get__(self, obj, objtype=None) -> Category_T:
if obj is None:
return self._default
return cast(Category_T, getattr(obj, self._lookup, None))
def __set__(self, obj, cat: Category_T):
setattr(obj, self._lookup, copy(cat))
class _AUTO_TYPE: ... # noqa: E701
......
from dataclasses import dataclass
from pystencils.codegen.config import SimpleOption, Option, Category, ConfigBase
def test_descriptors():
@dataclass
class SampleCategory(ConfigBase):
val1: SimpleOption[int] = SimpleOption(2)
val2: Option[bool, str] = Option(False)
@val2.validate
def _val2(self, v: str):
if v.lower() in ("off", "false", "no"):
return False
elif v.lower() in ("on", "true", "yes"):
return True
raise ValueError()
@dataclass
class SampleConfig(ConfigBase):
cat: Category[SampleCategory] = Category(SampleCategory())
val: SimpleOption[str] = SimpleOption("fallback")
cfg = SampleConfig()
# Check unset and default values
assert cfg.val is None
assert cfg.get_option("val") == "fallback"
# Check setting
cfg.val = "test"
assert cfg.val == "test"
assert cfg.get_option("val") == "test"
assert cfg.is_option_set("val")
# Check unsetting
cfg.val = None
assert not cfg.is_option_set("val")
assert cfg.val is None
# Check category
assert cfg.cat.val1 is None
assert cfg.cat.get_option("val1") == 2
assert cfg.cat.val2 is None
assert cfg.cat.get_option("val2") is False
# Check copy on category setting
c = SampleCategory(32, "on")
cfg.cat = c
assert cfg.cat.val1 == 32
assert cfg.cat.val2 is True
assert cfg.cat is not c
c.val1 = 13
assert cfg.cat.val1 == 32
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment