Skip to content
Snippets Groups Projects
Commit ca758480 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fixed configuration

parent 2b4b01d7
No related branches found
No related tags found
No related merge requests found
from __future__ import annotations
from typing import List, Sequence from typing import List, Sequence
from enum import Enum, auto from enum import Enum, auto
from dataclasses import dataclass, replace from dataclasses import dataclass, replace, asdict, fields
from argparse import ArgumentParser from argparse import ArgumentParser
from jinja2.filters import do_indent from jinja2.filters import do_indent
...@@ -44,13 +46,18 @@ class SfgConfiguration: ...@@ -44,13 +46,18 @@ class SfgConfiguration:
if self.header_only: if self.header_only:
raise SfgException( raise SfgException(
"Header-only code generation is not implemented yet.") "Header-only code generation is not implemented yet.")
if self.header_extension[0] == '.': if self.header_extension and self.header_extension[0] == '.':
self.header_extension = self.header_extension[1:] self.header_extension = self.header_extension[1:]
if self.source_extension[0] == '.': if self.source_extension and self.source_extension[0] == '.':
self.source_extension = self.source_extension[1:] self.source_extension = self.source_extension[1:]
def override(self, other: SfgConfiguration):
other_dict = asdict(other)
other_dict = {k: v for k, v in other_dict.items() if v is not None}
return replace(self, **other_dict)
DEFAULT_CONFIG = SfgConfiguration( DEFAULT_CONFIG = SfgConfiguration(
header_extension='h', header_extension='h',
...@@ -62,32 +69,11 @@ DEFAULT_CONFIG = SfgConfiguration( ...@@ -62,32 +69,11 @@ DEFAULT_CONFIG = SfgConfiguration(
) )
def get_file_extensions(self, extensions: Sequence[str]):
h_ext = None
src_ext = None
extensions = ((ext[1:] if ext[0] == '.' else ext) for ext in extensions)
for ext in extensions:
if ext in HEADER_FILE_EXTENSIONS:
if h_ext is not None:
raise ValueError("Multiple header file extensions found.")
h_ext = ext
elif ext in SOURCE_FILE_EXTENSIONS:
if src_ext is not None:
raise ValueError("Multiple source file extensions found.")
src_ext = ext
else:
raise ValueError(f"Don't know how to interpret extension '.{ext}'")
return h_ext, src_ext
def run_configurator(configurator_script: str): def run_configurator(configurator_script: str):
raise NotImplementedError() raise NotImplementedError()
def config_from_commandline(self, argv: List[str]): def config_from_commandline(argv: List[str]):
parser = ArgumentParser("pystencilssfg", parser = ArgumentParser("pystencilssfg",
description="pystencils Source File Generator", description="pystencils Source File Generator",
allow_abbrev=False) allow_abbrev=False)
...@@ -109,7 +95,7 @@ def config_from_commandline(self, argv: List[str]): ...@@ -109,7 +95,7 @@ def config_from_commandline(self, argv: List[str]):
project_config = None project_config = None
if args.file_extensions is not None: if args.file_extensions is not None:
h_ext, src_ext = get_file_extensions(args.file_extensions) h_ext, src_ext = _get_file_extensions(args.file_extensions)
else: else:
h_ext, src_ext = None, None h_ext, src_ext = None, None
...@@ -130,23 +116,44 @@ def merge_configurations(project_config: SfgConfiguration, ...@@ -130,23 +116,44 @@ def merge_configurations(project_config: SfgConfiguration,
config = DEFAULT_CONFIG config = DEFAULT_CONFIG
if project_config is not None: if project_config is not None:
config = replace(DEFAULT_CONFIG, **(project_config.asdict())) config = config.override(project_config)
if cmdline_config is not None: if cmdline_config is not None:
cmdline_dict = cmdline_config.asdict() cmdline_dict = asdict(cmdline_config)
# Commandline config completely overrides project and default config # Commandline config completely overrides project and default config
config = replace(config, **cmdline_dict) config = config.override(cmdline_config)
else: else:
cmdline_dict = {} cmdline_dict = {}
if script_config is not None: if script_config is not None:
# User config may only set values not specified on the command line # User config may only set values not specified on the command line
script_dict = script_config.asdict() script_dict = asdict(script_config)
for key, cmdline_value in cmdline_dict.items(): for key, cmdline_value in cmdline_dict.items():
if cmdline_value is not None and script_dict[key] is not None: if cmdline_value is not None and script_dict[key] is not None:
raise SfgException( raise SfgException(
f"Conflicting configuration: Parameter {key} was specified both in the script and on the command line.") f"Conflicting configuration: Parameter {key} was specified both in the script and on the command line.")
config = replace(config, **script_dict) config = config.override(script_config)
return config return config
def _get_file_extensions(extensions: Sequence[str]):
h_ext = None
src_ext = None
extensions = ((ext[1:] if ext[0] == '.' else ext) for ext in extensions)
for ext in extensions:
if ext in HEADER_FILE_EXTENSIONS:
if h_ext is not None:
raise ValueError("Multiple header file extensions found.")
h_ext = ext
elif ext in SOURCE_FILE_EXTENSIONS:
if src_ext is not None:
raise ValueError("Multiple source file extensions found.")
src_ext = ext
else:
raise ValueError(f"Don't know how to interpret extension '.{ext}'")
return h_ext, src_ext
...@@ -21,7 +21,10 @@ from .source_components import SfgFunction, SfgHeaderInclude ...@@ -21,7 +21,10 @@ from .source_components import SfgFunction, SfgHeaderInclude
class SourceFileGenerator: class SourceFileGenerator:
def __init__(self, sfg_config: SfgConfiguration): def __init__(self, sfg_config: SfgConfiguration = None):
if sfg_config and not isinstance(sfg_config, SfgConfiguration):
raise TypeError("sfg_config is not an SfgConfiguration.")
import __main__ import __main__
scriptpath = __main__.__file__ scriptpath = __main__.__file__
scriptname = path.split(scriptpath)[1] scriptname = path.split(scriptpath)[1]
...@@ -34,7 +37,7 @@ class SourceFileGenerator: ...@@ -34,7 +37,7 @@ class SourceFileGenerator:
self._context = SfgContext(script_args, config) self._context = SfgContext(script_args, config)
from .emitters.cpu.basic_cpu import BasicCpuEmitter from .emitters.cpu.basic_cpu import BasicCpuEmitter
self._emitter = BasicCpuEmitter(self._context, basename, config.output_directory) self._emitter = BasicCpuEmitter(basename, config)
def clean_files(self): def clean_files(self):
for file in self._emitter.output_files: for file in self._emitter.output_files:
...@@ -47,7 +50,7 @@ class SourceFileGenerator: ...@@ -47,7 +50,7 @@ class SourceFileGenerator:
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None: if exc_type is None:
self._emitter.write_files() self._emitter.write_files(self._context)
class SfgContext: class SfgContext:
......
...@@ -2,15 +2,15 @@ from jinja2 import Environment, PackageLoader, StrictUndefined ...@@ -2,15 +2,15 @@ from jinja2 import Environment, PackageLoader, StrictUndefined
from os import path from os import path
from ...configuration import SfgConfiguration
from ...context import SfgContext from ...context import SfgContext
class BasicCpuEmitter: class BasicCpuEmitter:
def __init__(self, ctx: SfgContext, basename: str, output_directory: str): def __init__(self, basename: str, config: SfgConfiguration):
self._ctx = ctx
self._basename = basename self._basename = basename
self._output_directory = output_directory self._output_directory = config.output_directory
self._header_filename = basename + ".h" self._header_filename = f"{basename}.{config.header_extension}"
self._cpp_filename = basename + ".cpp" self._cpp_filename = f"{basename}.{config.source_extension}"
@property @property
def output_files(self) -> str: def output_files(self) -> str:
...@@ -19,15 +19,15 @@ class BasicCpuEmitter: ...@@ -19,15 +19,15 @@ class BasicCpuEmitter:
path.join(self._output_directory, self._cpp_filename) path.join(self._output_directory, self._cpp_filename)
) )
def write_files(self): def write_files(self, ctx: SfgContext):
jinja_context = { jinja_context = {
'ctx': self._ctx, 'ctx': ctx,
'basename': self._basename, 'basename': self._basename,
'root_namespace': self._ctx.root_namespace, 'root_namespace': ctx.root_namespace,
'public_includes': list(incl.get_code() for incl in self._ctx.includes() if not incl.private), 'public_includes': list(incl.get_code() for incl in ctx.includes() if not incl.private),
'private_includes': list(incl.get_code() for incl in self._ctx.includes() if incl.private), 'private_includes': list(incl.get_code() for incl in ctx.includes() if incl.private),
'kernel_namespaces': list(self._ctx.kernel_namespaces()), 'kernel_namespaces': list(ctx.kernel_namespaces()),
'functions': list(self._ctx.functions()) 'functions': list(ctx.functions())
} }
template_name = "BasicCpu" template_name = "BasicCpu"
......
...@@ -7,7 +7,7 @@ from pystencilssfg import SourceFileGenerator ...@@ -7,7 +7,7 @@ from pystencilssfg import SourceFileGenerator
from pystencilssfg.source_concepts.cpp import std_mdspan from pystencilssfg.source_concepts.cpp import std_mdspan
with SourceFileGenerator("poisson") as sfg: with SourceFileGenerator() as sfg:
src, dst = ps.fields("src, dst(1) : double[2D]") src, dst = ps.fields("src, dst(1) : double[2D]")
h = sp.Symbol('h') h = sp.Symbol('h')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment