diff --git a/src/pystencilssfg/config.py b/src/pystencilssfg/config.py index 2090ef8e5eaddaa8489d2be257fc28992353bda0..051cf8dcedbd378851ad1f6cdfb70e8131f94a31 100644 --- a/src/pystencilssfg/config.py +++ b/src/pystencilssfg/config.py @@ -2,8 +2,13 @@ from __future__ import annotations from typing import Generic, TypeVar, Callable, Any from abc import ABC -from dataclasses import dataclass, fields +from dataclasses import dataclass, fields, field from enum import Enum, auto +from os import path +from importlib import util as iutil + + +class SfgConfigException(Exception): ... # noqa: E701 Option_T = TypeVar("Option_T") @@ -23,7 +28,7 @@ class Option(Generic[Option_T]): def __init__( self, - default: Option_T, + default: Option_T | None = None, validator: Callable[[Any, Option_T | None], Option_T | None] | None = None, ) -> None: self._default = default @@ -36,11 +41,10 @@ class Option(Generic[Option_T]): return validator @property - def default(self) -> Option_T: + def default(self) -> Option_T | None: return self._default - def get(self, obj) -> Option_T: - print("get called") + def get(self, obj) -> Option_T | None: val = getattr(obj, self._lookup, None) if val is None: return self._default @@ -62,6 +66,9 @@ class Option(Generic[Option_T]): value = self._validator(obj, value) setattr(obj, self._lookup, value) + def __delete__(self, obj): + delattr(obj, self._lookup) + class ConfigBase(ABC): def get_option(self, name: str) -> Any: @@ -70,12 +77,14 @@ class ConfigBase(ABC): return descr.get(self) def override(self, other: ConfigBase): - for field in fields(self): # type: ignore - fvalue = getattr(self, field.name) + for f in fields(self): # type: ignore + fvalue = getattr(self, f.name) if isinstance(fvalue, ConfigBase): # type: ignore - fvalue.override(getattr(other, field.name)) + fvalue.override(getattr(other, f.name)) else: - setattr(self, field.name, getattr(other, field.name)) + new_val = getattr(other, f.name) + if new_val is not None: + setattr(self, f.name, new_val) @dataclass @@ -83,7 +92,7 @@ class FileExtensions(ConfigBase): header: Option[str] = Option("hpp") """File extension for generated header file.""" - impl: Option[str | None] = Option(None) + impl: Option[str] = Option() """File extension for generated implementation file.""" @header.validate @@ -113,7 +122,7 @@ class SfgOutputMode(Enum): @dataclass -class SfgCodeStyle(ConfigBase): +class CodeStyle(ConfigBase): indent_width: Option[int] = Option(2) code_style: Option[str] = Option("file") @@ -141,7 +150,7 @@ GLOBAL_NAMESPACE = _GlobalNamespace() @dataclass class SfgConfig(ConfigBase): - extensions: FileExtensions = FileExtensions() + extensions: FileExtensions = field(default_factory=FileExtensions) """File extensions of the generated files""" output_mode: Option[SfgOutputMode] = Option(SfgOutputMode.STANDALONE) @@ -151,8 +160,51 @@ class SfgConfig(ConfigBase): """The outermost namespace in the generated file. May be a valid C++ nested namespace qualifier (like ``a::b::c``) or `GLOBAL_NAMESPACE` if no outer namespace should be generated.""" - codestyle: SfgCodeStyle = SfgCodeStyle() + codestyle: CodeStyle = field(default_factory=CodeStyle) """Options governing the code style used by the code generator""" output_directory: Option[str] = Option(".") """Directory to which the generated files should be written.""" + + +def run_configuration_module(configurator_script: str) -> SfgConfig: + """Run a configuration module. + + A configuration module must define a function called `configure_sfg` + with the following signature in its global namespace: + + ```Python + def configure_sfg(cfg: SfgConfig) -> None: ... + ``` + + After importing the module, that function will be called. + It should populate the ``cfg`` object with the desired configuration + options. + + TODO: The configuration module may optionally define a function ``project_info``, + which takes zero arguments and should return an containing project-specific information. + This object will later be available through the `project_info` member of the + `SfgContext`. + """ + + cfg_modulename = path.splitext(path.split(configurator_script)[1])[0] + + cfg_spec = iutil.spec_from_file_location(cfg_modulename, configurator_script) + + if cfg_spec is None: + raise SfgConfigException( + f"Unable to load configurator script {configurator_script}", + ) + + configurator = iutil.module_from_spec(cfg_spec) + cfg_spec.loader.exec_module(configurator) + + if not hasattr(configurator, "configure_sfg"): + raise SfgConfigException( + "Project configurator does not define function `configure_sfg`.", + ) + + cfg = SfgConfig() + project_config = configurator.configure_sfg(cfg) + + return project_config diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index aa8396c4ecae7471c333a14587435ac3f35eec3a..e7b0f5c24a71f0a41a15b559332e0d3d94817ad0 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -1,6 +1,3 @@ -# TODO -# mypy strict_optional=False - import sys import os from os import path diff --git a/tests/generator/test_config.py b/tests/generator/test_config.py index 1bc9d0e171ac02a35f0f04b01fb4f81825340d31..d7d38bd5020e583a2f74afe6178025665f2c7a23 100644 --- a/tests/generator/test_config.py +++ b/tests/generator/test_config.py @@ -12,6 +12,13 @@ def test_defaults(): cfg.extensions.impl = ".cu" assert cfg.extensions.get_option("impl") == "cu" + # Check that section subobjects of different config objects are independent + # -> must use default_factory to construct them, because they are mutable! + cfg.codestyle.clang_format_binary = "bogus" + + cfg2 = SfgConfig() + assert cfg2.codestyle.clang_format_binary is None + def test_override(): cfg1 = SfgConfig()