Skip to content
Snippets Groups Projects
Commit af279cf0 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

refactored prelude printing and file names

parent 1699c4db
No related branches found
No related tags found
No related merge requests found
Pipeline #57782 passed
......@@ -10,7 +10,7 @@ def sfg_config():
return SfgConfiguration(
header_extension='hpp',
source_extension='cpp',
impl_extension='cpp',
outer_namespace='cmake_demo',
project_info=project_info
)
......@@ -76,7 +76,7 @@ def list_files(args):
emitter = HeaderSourcePairEmitter(basename,
config.header_extension,
config.source_extension,
config.impl_extension,
config.output_directory)
print(args.sep.join(emitter.output_files), end=os.linesep if args.newline else '')
......
......@@ -14,7 +14,7 @@ from importlib import util as iutil
from .exceptions import SfgException
HEADER_FILE_EXTENSIONS = {'h', 'hpp'}
SOURCE_FILE_EXTENSIONS = {'c', 'cpp'}
IMPL_FILE_EXTENSIONS = {'c', 'cpp', '.impl.h'}
class SfgConfigSource(Enum):
......@@ -35,20 +35,57 @@ class SfgConfigException(Exception):
class SfgCodeStyle:
indent_width: int = 2
code_style: str = "LLVM"
"""Code style to be used by clang-format. Passed verbatim to `--style` argument of the clang-format CLI."""
force_clang_format: bool = False
"""If set to True, abort code generation if `clang-format` binary cannot be found."""
def indent(self, s: str):
prefix = " " * self.indent_width
return indent(s, prefix)
@dataclass
class SfgOutputSpec:
"""Name and path specification for files output by the code generator.
Filenames are constructed as `<output_directory>/<basename>.<extension>`."""
output_directory: str
"""Directory to which the generated files should be written."""
basename: str
"""Base name for output files."""
header_extension: str
"""File extension for generated header file."""
impl_extension: str
"""File extension for generated implementation file."""
def get_header_filename(self):
return f"{self.basename}.{self.header_extension}"
def get_impl_filename(self):
return f"{self.basename}.{self.impl_extension}"
def get_header_filepath(self):
return path.join(self.output_directory, self.get_header_filename())
def get_impl_filepath(self):
return path.join(self.output_directory, self.get_impl_filename())
@dataclass
class SfgConfiguration:
config_source: InitVar[SfgConfigSource | None] = None
header_extension: str | None = None
"""File extension for generated header files."""
"""File extension for generated header file."""
source_extension: str | None = None
"""File extension for generated source files."""
impl_extension: str | None = None
"""File extension for generated implementation file."""
header_only: bool | None = None
"""If set to `True`, generate only a header file without accompaning source file."""
......@@ -73,22 +110,34 @@ class SfgConfiguration:
if self.header_extension and self.header_extension[0] == '.':
self.header_extension = self.header_extension[1:]
if self.source_extension and self.source_extension[0] == '.':
self.source_extension = self.source_extension[1:]
if self.impl_extension and self.impl_extension[0] == '.':
self.impl_extension = self.impl_extension[1:]
def override(self, other: SfgConfiguration):
other_dict: dict[str, Any] = {k: v for k, v in asdict(other).items() if v is not None}
return replace(self, **other_dict)
def get_output_spec(self, basename: str) -> SfgOutputSpec:
assert self.header_extension is not None
assert self.impl_extension is not None
assert self.output_directory is not None
return SfgOutputSpec(
self.output_directory,
basename,
self.header_extension,
self.impl_extension
)
DEFAULT_CONFIG = SfgConfiguration(
config_source=SfgConfigSource.DEFAULT,
header_extension='h',
source_extension='cpp',
impl_extension='cpp',
header_only=False,
outer_namespace=None,
codestyle=SfgCodeStyle(),
output_directory=""
output_directory="."
)
......@@ -145,7 +194,7 @@ def config_from_parser_args(args):
cmdline_config = SfgConfiguration(
config_source=SfgConfigSource.COMMANDLINE,
header_extension=h_ext,
source_extension=src_ext,
impl_extension=src_ext,
header_only=args.header_only,
output_directory=args.output_directory
)
......@@ -207,7 +256,7 @@ def _get_file_extensions(cfgsrc: SfgConfigSource, extensions: Sequence[str]):
if h_ext is not None:
raise SfgConfigException(cfgsrc, "Multiple header file extensions specified.")
h_ext = ext
elif ext in SOURCE_FILE_EXTENSIONS:
elif ext in IMPL_FILE_EXTENSIONS:
if src_ext is not None:
raise SfgConfigException(cfgsrc, "Multiple source file extensions specified.")
src_ext = ext
......
from typing import cast
from jinja2 import Environment, PackageLoader, StrictUndefined
from textwrap import indent
from os import path
from ..configuration import SfgOutputSpec
from ..context import SfgContext
class HeaderSourcePairEmitter:
def __init__(self,
basename: str,
header_extension: str,
impl_extension: str,
output_directory: str):
self._basename = basename
self._output_directory = cast(str, output_directory)
self._header_filename = f"{basename}.{header_extension}"
self._source_filename = f"{basename}.{impl_extension}"
def __init__(self, output_spec: SfgOutputSpec):
self._basename = output_spec.basename
self._output_directory = output_spec.output_directory
self._header_filename = output_spec.get_header_filename()
self._impl_filename = output_spec.get_impl_filename()
self._ospec = output_spec
@property
def output_files(self) -> tuple[str, str]:
return (
path.join(self._output_directory, self._header_filename),
path.join(self._output_directory, self._source_filename)
path.join(self._output_directory, self._impl_filename)
)
def write_files(self, ctx: SfgContext):
......@@ -31,9 +28,9 @@ class HeaderSourcePairEmitter:
jinja_context = {
'ctx': ctx,
'header_filename': self._header_filename,
'source_filename': self._source_filename,
'source_filename': self._impl_filename,
'basename': self._basename,
'prelude': get_prelude_comment(ctx),
'prelude_comment': ctx.prelude_comment,
'definitions': list(ctx.definitions()),
'fq_namespace': fq_namespace,
'public_includes': list(incl.get_code() for incl in ctx.includes() if not incl.private),
......@@ -55,15 +52,8 @@ class HeaderSourcePairEmitter:
header = env.get_template(f"{template_name}.tmpl.h").render(**jinja_context)
source = env.get_template(f"{template_name}.tmpl.cpp").render(**jinja_context)
with open(path.join(self._output_directory, self._header_filename), 'w') as headerfile:
with open(self._ospec.get_header_filepath(), 'w') as headerfile:
headerfile.write(header)
with open(path.join(self._output_directory, self._source_filename), 'w') as cppfile:
with open(self._ospec.get_impl_filepath(), 'w') as cppfile:
cppfile.write(source)
def get_prelude_comment(ctx: SfgContext):
if not ctx.prelude_comment:
return ""
return "/*\n" + indent(ctx.prelude_comment, "* ", predicate=lambda _: True) + "*/\n"
from jinja2 import pass_context
from textwrap import indent
from pystencils.astnodes import KernelFunction
from pystencils import Backend
......@@ -7,6 +8,13 @@ from pystencils.backends import generate_c
from pystencilssfg.source_components import SfgFunction
def format_prelude_comment(prelude_comment: str):
if not prelude_comment:
return ""
return "/*\n" + indent(prelude_comment, "* ", predicate=lambda _: True) + "*/\n"
@pass_context
def generate_kernel_definition(ctx, ast: KernelFunction):
return generate_c(ast, dialect=Backend.C)
......@@ -23,6 +31,7 @@ def generate_function_body(func: SfgFunction):
def add_filters_to_jinja(jinja_env):
jinja_env.filters['format_prelude_comment'] = format_prelude_comment
jinja_env.filters['generate_kernel_definition'] = generate_kernel_definition
jinja_env.filters['generate_function_parameter_list'] = generate_function_parameter_list
jinja_env.filters['generate_function_body'] = generate_function_body
{{ prelude }}
{{ prelude_comment | format_prelude_comment }}
#include "{{header_filename}}"
......
{{ prelude }}
{{ prelude_comment | format_prelude_comment }}
#pragma once
......
......@@ -27,10 +27,7 @@ class SourceFileGenerator:
self._context = SfgContext(config.outer_namespace, config.codestyle, argv=script_args)
from .emitters import HeaderSourcePairEmitter
self._emitter = HeaderSourcePairEmitter(basename,
config.header_extension,
config.source_extension,
config.output_directory)
self._emitter = HeaderSourcePairEmitter(config.get_output_spec(basename))
def clean_files(self):
for file in self._emitter.output_files:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment