Skip to content
Snippets Groups Projects
Commit 2edd363e authored by Frederik Hennig's avatar Frederik Hennig
Browse files

More examples for composer. Fix generator script tests ground-truth comparison.

parent a662f194
No related branches found
No related tags found
No related merge requests found
from .configuration import SfgConfiguration, SfgOutputMode from .configuration import SfgConfiguration, SfgOutputMode, SfgCodeStyle
from .generator import SourceFileGenerator from .generator import SourceFileGenerator
from .composer import SfgComposer from .composer import SfgComposer
from .context import SfgContext from .context import SfgContext
...@@ -9,6 +9,7 @@ __all__ = [ ...@@ -9,6 +9,7 @@ __all__ = [
"SfgComposer", "SfgComposer",
"SfgConfiguration", "SfgConfiguration",
"SfgOutputMode", "SfgOutputMode",
"SfgCodeStyle",
"SfgContext", "SfgContext",
"AugExpr", "AugExpr",
] ]
......
...@@ -96,22 +96,63 @@ class SfgBasicComposer(SfgIComposer): ...@@ -96,22 +96,63 @@ class SfgBasicComposer(SfgIComposer):
The string should not contain C/C++ comment delimiters, since these will be added automatically The string should not contain C/C++ comment delimiters, since these will be added automatically
during code generation. during code generation.
:Example:
>>> sfg.prelude("This file was generated using pystencils-sfg; do not modify it directly!")
will appear in the generated files as
.. code-block:: C++
/*
* This file was generated using pystencils-sfg; do not modify it directly!
*/
""" """
self._ctx.append_to_prelude(content) self._ctx.append_to_prelude(content)
def define(self, *definitions: str): def define(self, *definitions: str):
"""Add custom definitions to the generated header file.""" """Add custom definitions to the generated header file.
Each string passed to this method will be printed out directly into the generated header file.
:Example:
>>> sfg.define("#define PI 3.14 // more than enough for engineers")
will appear as
.. code-block:: C++
#define PI 3.14 // more than enough for engineers
"""
for d in definitions: for d in definitions:
self._ctx.add_definition(d) self._ctx.add_definition(d)
def define_once(self, *definitions: str): def define_once(self, *definitions: str):
"""Same as `define`, but only adds definitions only if the same code string was not already added.""" """Same as `define`, but only adds definitions if the same code string was not already added."""
for definition in definitions: for definition in definitions:
if all(d != definition for d in self._ctx.definitions()): if all(d != definition for d in self._ctx.definitions()):
self._ctx.add_definition(definition) self._ctx.add_definition(definition)
def namespace(self, namespace: str): def namespace(self, namespace: str):
"""Set the inner code namespace. Throws an exception if a namespace was already set.""" """Set the inner code namespace. Throws an exception if a namespace was already set.
:Example:
After adding the following to your generator script:
>>> sfg.namespace("codegen_is_awesome")
All generated code will be placed within that namespace:
.. code-block:: C++
namespace codegen_is_awesome {
/* all generated code */
}
"""
self._ctx.set_namespace(namespace) self._ctx.set_namespace(namespace)
def generate(self, generator: CustomGenerator): def generate(self, generator: CustomGenerator):
...@@ -142,9 +183,21 @@ class SfgBasicComposer(SfgIComposer): ...@@ -142,9 +183,21 @@ class SfgBasicComposer(SfgIComposer):
"""Include a header file. """Include a header file.
Args: Args:
header_file: Path to the header file. Enclose in `<>` for a system header. header_file: Path to the header file. Enclose in ``<>`` for a system header.
private: If `True`, in header-implementation code generation, the header file is private: If ``True``, in header-implementation code generation, the header file is
only included in the implementation file. only included in the implementation file.
:Example:
>>> sfg.include("<vector>")
>>> sfg.include("custom.h")
will be printed as
.. code-block:: C++
#include <vector>
#include "custom.h"
""" """
self._ctx.add_include(SfgHeaderInclude.parse(header_file, private)) self._ctx.add_include(SfgHeaderInclude.parse(header_file, private))
......
...@@ -40,7 +40,10 @@ class SfgCodeStyle: ...@@ -40,7 +40,10 @@ class SfgCodeStyle:
""" """
force_clang_format: bool = False force_clang_format: bool = False
"""If set to True, abort code generation if `clang-format` binary cannot be found.""" """If set to True, abort code generation if ``clang-format`` binary cannot be found."""
skip_clang_format: bool = False
"""If set to True, skip formatting using ``clang-format``."""
clang_format_binary: str = "clang-format" clang_format_binary: str = "clang-format"
"""Path to the clang-format executable""" """Path to the clang-format executable"""
......
...@@ -69,7 +69,7 @@ class SfgContext: ...@@ -69,7 +69,7 @@ class SfgContext:
# Source Components # Source Components
self._prelude: str = "" self._prelude: str = ""
self._includes: set[SfgHeaderInclude] = set() self._includes: list[SfgHeaderInclude] = []
self._definitions: list[str] = [] self._definitions: list[str] = []
self._kernel_namespaces = { self._kernel_namespaces = {
self._default_kernel_namespace.name: self._default_kernel_namespace self._default_kernel_namespace.name: self._default_kernel_namespace
...@@ -79,10 +79,6 @@ class SfgContext: ...@@ -79,10 +79,6 @@ class SfgContext:
self._declarations_ordered: list[str | SfgFunction | SfgClass] = list() self._declarations_ordered: list[str | SfgFunction | SfgClass] = list()
# Standard stuff
self.add_include(SfgHeaderInclude("cstdint", system_header=True))
self.add_definition("#define RESTRICT __restrict__")
@property @property
def argv(self) -> Sequence[str]: def argv(self) -> Sequence[str]:
"""If this context was created by a `pystencilssfg.SourceFileGenerator`, provides the command """If this context was created by a `pystencilssfg.SourceFileGenerator`, provides the command
...@@ -159,7 +155,7 @@ class SfgContext: ...@@ -159,7 +155,7 @@ class SfgContext:
yield from self._includes yield from self._includes
def add_include(self, include: SfgHeaderInclude): def add_include(self, include: SfgHeaderInclude):
self._includes.add(include) self._includes.append(include)
def definitions(self) -> Generator[str, None, None]: def definitions(self) -> Generator[str, None, None]:
"""Definitions are arbitrary custom lines of code.""" """Definitions are arbitrary custom lines of code."""
......
...@@ -24,6 +24,9 @@ def invoke_clang_format(code: str, codestyle: SfgCodeStyle) -> str: ...@@ -24,6 +24,9 @@ def invoke_clang_format(code: str, codestyle: SfgCodeStyle) -> str:
be executed (binary not found, or error during exection), the function will be executed (binary not found, or error during exection), the function will
throw an exception. throw an exception.
""" """
if codestyle.skip_clang_format:
return code
args = [codestyle.clang_format_binary, f"--style={codestyle.code_style}"] args = [codestyle.clang_format_binary, f"--style={codestyle.code_style}"]
if not shutil.which(codestyle.clang_format_binary): if not shutil.which(codestyle.clang_format_binary):
......
...@@ -178,14 +178,6 @@ class SfgHeaderPrinter(SfgGeneralPrinter): ...@@ -178,14 +178,6 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
return code return code
def delimiter(content):
return f"""\
/*************************************************************************************
* {content}
*************************************************************************************/
"""
class SfgImplPrinter(SfgGeneralPrinter): class SfgImplPrinter(SfgGeneralPrinter):
def __init__( def __init__(
self, ctx: SfgContext, output_spec: SfgOutputSpec, inline_impl: bool = False self, ctx: SfgContext, output_spec: SfgOutputSpec, inline_impl: bool = False
...@@ -219,11 +211,8 @@ class SfgImplPrinter(SfgGeneralPrinter): ...@@ -219,11 +211,8 @@ class SfgImplPrinter(SfgGeneralPrinter):
parts = interleave( parts = interleave(
chain( chain(
[delimiter("Kernels")],
ctx.kernel_namespaces(), ctx.kernel_namespaces(),
[delimiter("Functions")],
ctx.functions(), ctx.functions(),
[delimiter("Class Methods")],
ctx.classes(), ctx.classes(),
), ),
repeat(SfgEmptyLines(1)), repeat(SfgEmptyLines(1)),
......
...@@ -58,6 +58,11 @@ class SourceFileGenerator: ...@@ -58,6 +58,11 @@ class SourceFileGenerator:
project_info=config.project_info, project_info=config.project_info,
) )
from pystencilssfg.ir import SfgHeaderInclude
self._context.add_include(SfgHeaderInclude("cstdint", system_header=True))
self._context.add_definition("#define RESTRICT __restrict__")
self._emitter: AbstractEmitter self._emitter: AbstractEmitter
match config.output_mode: match config.output_mode:
case SfgOutputMode.HEADER_ONLY: case SfgOutputMode.HEADER_ONLY:
......
#pragma once
#include <cstdint> #include <cstdint>
#define RESTRICT __restrict__
class Point { class Point {
public: public:
const int64_t & getX() const { const int64_t & getX() const {
......
/*
* Expect the unexpected, and you shall never be surprised.
*/
#pragma once
#include <cstdint>
#include <iostream>
#include "config.h"
namespace awesome {
#define RESTRICT __restrict__
#define PI 3.1415
using namespace std;
}
...@@ -6,7 +6,7 @@ with SourceFileGenerator() as sfg: ...@@ -6,7 +6,7 @@ with SourceFileGenerator() as sfg:
sfg.klass("Point")( sfg.klass("Point")(
sfg.public( sfg.public(
sfg.method("getX", returns="const int64_t &", const=True)( sfg.method("getX", returns="const int64_t &", const=True, inline=True)(
"return this->x;" "return this->x;"
) )
), ),
......
from pystencilssfg import SourceFileGenerator, SfgConfiguration, SfgCodeStyle
# Do not use clang-format, since it reorders headers
cfg = SfgConfiguration(
codestyle=SfgCodeStyle(skip_clang_format=True)
)
with SourceFileGenerator(cfg) as sfg:
sfg.prelude("Expect the unexpected, and you shall never be surprised.")
sfg.include("<iostream>")
sfg.include("config.h")
sfg.namespace("awesome")
sfg.define("#define PI 3.1415")
sfg.define("using namespace std;")
...@@ -14,21 +14,25 @@ EXPECTED_DIR = path.join(THIS_DIR, "expected") ...@@ -14,21 +14,25 @@ EXPECTED_DIR = path.join(THIS_DIR, "expected")
@dataclass @dataclass
class ScriptInfo: class ScriptInfo:
@staticmethod
def make(name, *args, **kwargs):
return pytest.param(ScriptInfo(name, *args, **kwargs), id=f"{name}.py")
script_name: str script_name: str
"""Name of the generator script, without .py-extension. """Name of the generator script, without .py-extension.
Generator scripts must be located in the ``scripts`` folder. Generator scripts must be located in the ``scripts`` folder.
""" """
expected_outputs: tuple[str, ...] expected_outputs: tuple[str, ...]
"""List of file extensions expected to be emitted by the generator script. """List of file extensions expected to be emitted by the generator script.
Output files will all be placed in the ``out`` folder. Output files will all be placed in the ``out`` folder.
""" """
compilable_output: str | None = None compilable_output: str | None = None
"""File extension of the output file that can be compiled. """File extension of the output file that can be compiled.
If this is set, and the expected file exists, the ``compile_cmd`` will be If this is set, and the expected file exists, the ``compile_cmd`` will be
executed to check for error-free compilation of the output. executed to check for error-free compilation of the output.
""" """
...@@ -36,11 +40,15 @@ class ScriptInfo: ...@@ -36,11 +40,15 @@ class ScriptInfo:
compile_cmd: str = f"g++ --std=c++17 -I {THIS_DIR}/deps/mdspan/include" compile_cmd: str = f"g++ --std=c++17 -I {THIS_DIR}/deps/mdspan/include"
"""Command to be invoked to compile the generated source file.""" """Command to be invoked to compile the generated source file."""
def __repr__(self) -> str:
return self.script_name
SCRIPTS = [ SCRIPTS = [
ScriptInfo("SimpleJacobi", ("h", "cpp"), compilable_output="cpp"), ScriptInfo.make("Structural", ("h", "cpp")),
ScriptInfo("SimpleClasses", ("h", "cpp")), ScriptInfo.make("SimpleJacobi", ("h", "cpp"), compilable_output="cpp"),
ScriptInfo("Variables", ("h", "cpp"), compilable_output="cpp"), ScriptInfo.make("SimpleClasses", ("h", "cpp")),
ScriptInfo.make("Variables", ("h", "cpp"), compilable_output="cpp"),
] ]
...@@ -92,7 +100,7 @@ def test_generator_script(script_info: ScriptInfo): ...@@ -92,7 +100,7 @@ def test_generator_script(script_info: ScriptInfo):
# Strip whitespace # Strip whitespace
expected = "".join(expected.split()) expected = "".join(expected.split())
actual = "".join(expected.split()) actual = "".join(actual.split())
assert expected == actual assert expected == actual
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment