Skip to content
Snippets Groups Projects
Commit 271404a6 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

update and fix large parts of the test suite

parent c9683d17
No related branches found
No related tags found
1 merge request!17Improved Source File and Code Structure Modelling
Pipeline #73666 failed
Showing
with 104 additions and 76 deletions
...@@ -2,21 +2,30 @@ import pytest ...@@ -2,21 +2,30 @@ import pytest
from os import path from os import path
@pytest.fixture(autouse=True)
def prepare_doctest_namespace(doctest_namespace):
from pystencilssfg import SfgContext, SfgComposer
from pystencilssfg import lang
# Place a composer object in the environment for doctests
sfg = SfgComposer(SfgContext())
doctest_namespace["sfg"] = sfg
doctest_namespace["lang"] = lang
DATA_DIR = path.join(path.split(__file__)[0], "tests/data") DATA_DIR = path.join(path.split(__file__)[0], "tests/data")
@pytest.fixture @pytest.fixture
def sample_config_module(): def sample_config_module():
return path.join(DATA_DIR, "project_config.py") return path.join(DATA_DIR, "project_config.py")
@pytest.fixture
def sfg():
from pystencilssfg import SfgContext, SfgComposer
from pystencilssfg.ir import SfgSourceFile, SfgSourceFileType
return SfgComposer(
SfgContext(
header_file=SfgSourceFile("", SfgSourceFileType.HEADER),
impl_file=SfgSourceFile("", SfgSourceFileType.TRANSLATION_UNIT),
)
)
@pytest.fixture(autouse=True)
def prepare_doctest_namespace(doctest_namespace, sfg):
from pystencilssfg import lang
doctest_namespace["sfg"] = sfg
doctest_namespace["lang"] = lang
import sys import sys
import os import os
from os import path from os import path
from pathlib import Path
from typing import NoReturn from typing import NoReturn
from argparse import ArgumentParser, BooleanOptionalAction from argparse import ArgumentParser, BooleanOptionalAction
from .config import CommandLineParameters, SfgConfigException, OutputMode from .config import CommandLineParameters, SfgConfigException
def add_newline_arg(parser): def add_newline_arg(parser):
...@@ -81,13 +80,7 @@ def list_files(args) -> NoReturn: ...@@ -81,13 +80,7 @@ def list_files(args) -> NoReturn:
_, scriptname = path.split(args.codegen_script) _, scriptname = path.split(args.codegen_script)
basename = path.splitext(scriptname)[0] basename = path.splitext(scriptname)[0]
output_dir: Path = config.get_option("output_directory") output_files = config._get_output_files(basename)
header_ext = config.extensions.get_option("header")
output_files = [output_dir / f"{basename}.{header_ext}"]
if config.output_mode != OutputMode.HEADER_ONLY:
impl_ext = config.extensions.get_option("impl")
output_files.append(output_dir / f"{basename}.{impl_ext}")
print( print(
args.sep.join(str(of) for of in output_files), args.sep.join(str(of) for of in output_files),
......
...@@ -173,6 +173,26 @@ class SfgConfig(ConfigBase): ...@@ -173,6 +173,26 @@ class SfgConfig(ConfigBase):
def _validate_output_directory(self, pth: str | Path) -> Path: def _validate_output_directory(self, pth: str | Path) -> Path:
return Path(pth) return Path(pth)
def _get_output_files(self, basename: str):
output_dir: Path = self.get_option("output_directory")
header_ext = self.extensions.get_option("header")
impl_ext = self.extensions.get_option("impl")
output_files = [output_dir / f"{basename}.{header_ext}"]
output_mode = self.get_option("output_mode")
if impl_ext is None:
match output_mode:
case OutputMode.INLINE:
impl_ext = "ipp"
case OutputMode.STANDALONE:
impl_ext = "cpp"
if impl_ext is not None:
output_files.append(output_dir / f"{basename}.{impl_ext}")
return tuple(output_files)
class CommandLineParameters: class CommandLineParameters:
@staticmethod @staticmethod
......
...@@ -102,7 +102,7 @@ class SfgCursor: ...@@ -102,7 +102,7 @@ class SfgCursor:
self._cur_namespace: SfgNamespace = namespace self._cur_namespace: SfgNamespace = namespace
self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] = dict()
for f in self._ctx.files: for f in self._ctx.files:
if self._cur_namespace is not None: if self._cur_namespace is not None:
block = SfgNamespaceBlock(self._cur_namespace) block = SfgNamespaceBlock(self._cur_namespace)
......
...@@ -35,12 +35,12 @@ class SfgFilePrinter: ...@@ -35,12 +35,12 @@ class SfgFilePrinter:
code = "" code = ""
if file.file_type == SfgSourceFileType.HEADER: if file.file_type == SfgSourceFileType.HEADER:
code += "#pragma once\n" code += "#pragma once\n\n"
if file.prelude: if file.prelude:
comment = "/**\n" comment = "/**\n"
comment += indent(file.prelude, " * ") comment += indent(file.prelude, " * ")
comment += "\n */\n\n" comment += " */\n\n"
code += comment code += comment
...@@ -54,6 +54,7 @@ class SfgFilePrinter: ...@@ -54,6 +54,7 @@ class SfgFilePrinter:
# Here begins the actual code # Here begins the actual code
code += "\n\n".join(self.visit(elem) for elem in file.elements) code += "\n\n".join(self.visit(elem) for elem in file.elements)
code += "\n" code += "\n"
return code return code
def visit( def visit(
...@@ -73,6 +74,8 @@ class SfgFilePrinter: ...@@ -73,6 +74,8 @@ class SfgFilePrinter:
return self.visit_decl(entity, inclass) return self.visit_decl(entity, inclass)
case SfgEntityDef(entity): case SfgEntityDef(entity):
return self.visit_defin(entity, inclass) return self.visit_defin(entity, inclass)
case SfgClassBody():
return self.visit_defin(elem, inclass)
case _: case _:
assert False, "illegal code element" assert False, "illegal code element"
...@@ -173,7 +176,7 @@ class SfgFilePrinter: ...@@ -173,7 +176,7 @@ class SfgFilePrinter:
code = "" code = ""
if func.inline: if func.inline:
code += "inline " code += "inline "
code += func.return_type.c_string() code += func.return_type.c_string() + " "
params_str = ", ".join( params_str = ", ".join(
f"{param.dtype.c_string()} {param.name}" for param in func.parameters f"{param.dtype.c_string()} {param.name}" for param in func.parameters
) )
......
from pathlib import Path from pathlib import Path
from .config import SfgConfig, CommandLineParameters, OutputMode, GLOBAL_NAMESPACE from .config import (
SfgConfig,
CommandLineParameters,
OutputMode,
_GlobalNamespace,
)
from .context import SfgContext from .context import SfgContext
from .composer import SfgComposer from .composer import SfgComposer
from .emission import SfgCodeEmitter from .emission import SfgCodeEmitter
from .exceptions import SfgException from .exceptions import SfgException
from .lang import HeaderFile
class SourceFileGenerator: class SourceFileGenerator:
...@@ -26,7 +32,10 @@ class SourceFileGenerator: ...@@ -26,7 +32,10 @@ class SourceFileGenerator:
""" """
def __init__( def __init__(
self, sfg_config: SfgConfig | None = None, keep_unknown_argv: bool = False self,
sfg_config: SfgConfig | None = None,
namespace: str | None = None,
keep_unknown_argv: bool = False,
): ):
if sfg_config and not isinstance(sfg_config, SfgConfig): if sfg_config and not isinstance(sfg_config, SfgConfig):
raise TypeError("sfg_config is not an SfgConfiguration.") raise TypeError("sfg_config is not an SfgConfiguration.")
...@@ -68,13 +77,13 @@ class SourceFileGenerator: ...@@ -68,13 +77,13 @@ class SourceFileGenerator:
self._output_mode: OutputMode = config.get_option("output_mode") self._output_mode: OutputMode = config.get_option("output_mode")
self._output_dir: Path = config.get_option("output_directory") self._output_dir: Path = config.get_option("output_directory")
self._header_ext: str = config.extensions.get_option("header")
self._impl_ext: str = config.extensions.get_option("impl") output_files = config._get_output_files(basename)
from .ir import SfgSourceFile, SfgSourceFileType from .ir import SfgSourceFile, SfgSourceFileType
self._header_file = SfgSourceFile( self._header_file = SfgSourceFile(
f"{basename}.{self._header_ext}", SfgSourceFileType.HEADER output_files[0].name, SfgSourceFileType.HEADER
) )
self._impl_file: SfgSourceFile | None self._impl_file: SfgSourceFile | None
...@@ -83,17 +92,29 @@ class SourceFileGenerator: ...@@ -83,17 +92,29 @@ class SourceFileGenerator:
self._impl_file = None self._impl_file = None
case OutputMode.STANDALONE: case OutputMode.STANDALONE:
self._impl_file = SfgSourceFile( self._impl_file = SfgSourceFile(
f"{basename}.{self._impl_ext}", SfgSourceFileType.TRANSLATION_UNIT output_files[1].name, SfgSourceFileType.TRANSLATION_UNIT
)
self._impl_file.includes.append(
HeaderFile.parse(self._header_file.name)
) )
case OutputMode.INLINE: case OutputMode.INLINE:
self._impl_file = SfgSourceFile( self._impl_file = SfgSourceFile(
f"{basename}.{self._impl_ext}", SfgSourceFileType.HEADER output_files[1].name, SfgSourceFileType.HEADER
) )
outer_namespace: str | _GlobalNamespace = config.get_option("outer_namespace")
match (outer_namespace, namespace):
case [_GlobalNamespace(), None]:
namespace = None
case [_GlobalNamespace(), nspace] | [nspace, None]:
namespace = nspace
case [outer, inner]:
namespace = f"{outer}::{inner}"
self._context = SfgContext( self._context = SfgContext(
self._header_file, self._header_file,
self._impl_file, self._impl_file,
None if config.outer_namespace is GLOBAL_NAMESPACE else config.outer_namespace, # type: ignore namespace,
config.codestyle, config.codestyle,
argv=script_args, argv=script_args,
project_info=cli_params.get_project_info(), project_info=cli_params.get_project_info(),
...@@ -119,6 +140,10 @@ class SourceFileGenerator: ...@@ -119,6 +140,10 @@ class SourceFileGenerator:
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None: if exc_type is None:
if self._output_mode == OutputMode.INLINE:
assert self._impl_file is not None
self._header_file.elements.append(f'#include "{self._impl_file.name}"')
# Collect header files for inclusion # Collect header files for inclusion
# from .ir import collect_includes # from .ir import collect_includes
......
...@@ -5,8 +5,8 @@ import pystencils as ps ...@@ -5,8 +5,8 @@ import pystencils as ps
from pystencilssfg import SfgContext from pystencilssfg import SfgContext
def test_parallel_for_1_kernels(): def test_parallel_for_1_kernels(sfg):
sfg = sycl.SyclComposer(SfgContext()) sfg = sycl.SyclComposer(sfg)
data_type = "double" data_type = "double"
dim = 2 dim = 2
f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]") f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]")
...@@ -24,8 +24,8 @@ def test_parallel_for_1_kernels(): ...@@ -24,8 +24,8 @@ def test_parallel_for_1_kernels():
) )
def test_parallel_for_2_kernels(): def test_parallel_for_2_kernels(sfg):
sfg = sycl.SyclComposer(SfgContext()) sfg = sycl.SyclComposer(sfg)
data_type = "double" data_type = "double"
dim = 2 dim = 2
f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]") f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]")
...@@ -43,8 +43,8 @@ def test_parallel_for_2_kernels(): ...@@ -43,8 +43,8 @@ def test_parallel_for_2_kernels():
) )
def test_parallel_for_2_kernels_fail(): def test_parallel_for_2_kernels_fail(sfg):
sfg = sycl.SyclComposer(SfgContext()) sfg = sycl.SyclComposer(sfg)
data_type = "double" data_type = "double"
dim = 2 dim = 2
f, g = ps.fields(f"f,g:{data_type}[{dim}D]") f, g = ps.fields(f"f,g:{data_type}[{dim}D]")
......
import pytest import pytest
from pathlib import Path
from pystencilssfg.config import ( from pystencilssfg.config import (
SfgConfig, SfgConfig,
...@@ -86,7 +87,7 @@ def test_from_commandline(sample_config_module): ...@@ -86,7 +87,7 @@ def test_from_commandline(sample_config_module):
cli_args = CommandLineParameters(args) cli_args = CommandLineParameters(args)
cfg = cli_args.get_config() cfg = cli_args.get_config()
assert cfg.output_directory == ".out" assert cfg.output_directory == Path(".out")
assert cfg.extensions.header == "h++" assert cfg.extensions.header == "h++"
assert cfg.extensions.impl == "c++" assert cfg.extensions.impl == "c++"
...@@ -100,7 +101,7 @@ def test_from_commandline(sample_config_module): ...@@ -100,7 +101,7 @@ def test_from_commandline(sample_config_module):
assert cfg.clang_format.code_style == "llvm" assert cfg.clang_format.code_style == "llvm"
assert cfg.clang_format.skip is True assert cfg.clang_format.skip is True
assert ( assert (
cfg.output_directory == "gen_sources" cfg.output_directory == Path("gen_sources")
) # value from config module overridden by commandline ) # value from config module overridden by commandline
assert cfg.outer_namespace == "myproject" assert cfg.outer_namespace == "myproject"
assert cfg.extensions.header == "hpp" assert cfg.extensions.header == "hpp"
......
...@@ -4,12 +4,10 @@ from pystencilssfg import SourceFileGenerator, SfgConfig ...@@ -4,12 +4,10 @@ from pystencilssfg import SourceFileGenerator, SfgConfig
cfg = SfgConfig() cfg = SfgConfig()
cfg.clang_format.skip = True cfg.clang_format.skip = True
with SourceFileGenerator(cfg) as sfg: with SourceFileGenerator(cfg, namespace="awesome") as sfg:
sfg.prelude("Expect the unexpected, and you shall never be surprised.") sfg.prelude("Expect the unexpected, and you shall never be surprised.")
sfg.include("<iostream>") sfg.include("<iostream>")
sfg.include("config.h") sfg.include("config.h")
sfg.namespace("awesome")
sfg.code("#define PI 3.1415") sfg.code("#define PI 3.1415")
sfg.code("using namespace std;") sfg.code("using namespace std;")
from pystencilssfg import SourceFileGenerator from pystencilssfg import SourceFileGenerator
from pystencils.types import PsCustomType from pystencils.types import PsCustomType
with SourceFileGenerator() as sfg: with SourceFileGenerator(namespace="gen") as sfg:
sfg.namespace("gen")
sfg.include("<iostream>") sfg.include("<iostream>")
sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };") sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };")
......
...@@ -7,9 +7,7 @@ from pystencilssfg.lang.cpp.std import mdspan ...@@ -7,9 +7,7 @@ from pystencilssfg.lang.cpp.std import mdspan
mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>")
with SourceFileGenerator() as sfg: with SourceFileGenerator(namespace="gen") as sfg:
sfg.namespace("gen")
u_src, u_dst, f = fields("u_src, u_dst, f(1) : double[2D]", layout="fzyx") u_src, u_dst, f = fields("u_src, u_dst, f(1) : double[2D]", layout="fzyx")
h = sp.Symbol("h") h = sp.Symbol("h")
......
...@@ -5,8 +5,7 @@ from pystencilssfg.lang import strip_ptr_ref ...@@ -5,8 +5,7 @@ from pystencilssfg.lang import strip_ptr_ref
std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>")
with SourceFileGenerator() as sfg: with SourceFileGenerator(namespace="gen") as sfg:
sfg.namespace("gen")
sfg.include("<cassert>") sfg.include("<cassert>")
def check_layout(field: ps.Field, mdspan: std.mdspan): def check_layout(field: ps.Field, mdspan: std.mdspan):
......
...@@ -43,8 +43,7 @@ def lbm_stream(sfg: SfgComposer, field_layout: str, layout_policy: str): ...@@ -43,8 +43,7 @@ def lbm_stream(sfg: SfgComposer, field_layout: str, layout_policy: str):
) )
with SourceFileGenerator() as sfg: with SourceFileGenerator(namespace="gen") as sfg:
sfg.namespace("gen")
sfg.include("<cassert>") sfg.include("<cassert>")
sfg.include("<array>") sfg.include("<array>")
......
...@@ -2,7 +2,7 @@ from pystencils import TypedSymbol, fields, kernel ...@@ -2,7 +2,7 @@ from pystencils import TypedSymbol, fields, kernel
from pystencilssfg import SourceFileGenerator from pystencilssfg import SourceFileGenerator
with SourceFileGenerator() as sfg: with SourceFileGenerator(namespace="gen") as sfg:
N = 10 N = 10
α = TypedSymbol("alpha", "float32") α = TypedSymbol("alpha", "float32")
src, dst = fields(f"src, dst: float32[{N}]") src, dst = fields(f"src, dst: float32[{N}]")
...@@ -13,7 +13,6 @@ with SourceFileGenerator() as sfg: ...@@ -13,7 +13,6 @@ with SourceFileGenerator() as sfg:
khandle = sfg.kernels.create(scale) khandle = sfg.kernels.create(scale)
sfg.namespace("gen")
sfg.code(f"constexpr int N = {N};") sfg.code(f"constexpr int N = {N};")
sfg.klass("Scale")( sfg.klass("Scale")(
......
...@@ -5,9 +5,7 @@ from pystencilssfg import SourceFileGenerator ...@@ -5,9 +5,7 @@ from pystencilssfg import SourceFileGenerator
from pystencilssfg.lang.cpp import std from pystencilssfg.lang.cpp import std
with SourceFileGenerator() as sfg: with SourceFileGenerator(namespace="StlContainers1D::gen") as sfg:
sfg.namespace("StlContainers1D::gen")
src, dst = ps.fields("src, dst: double[1D]") src, dst = ps.fields("src, dst: double[1D]")
asms = [ asms = [
......
...@@ -4,9 +4,8 @@ from pystencilssfg import SourceFileGenerator ...@@ -4,9 +4,8 @@ from pystencilssfg import SourceFileGenerator
import pystencilssfg.extensions.sycl as sycl import pystencilssfg.extensions.sycl as sycl
with SourceFileGenerator() as sfg: with SourceFileGenerator(namespace="gen") as sfg:
sfg = sycl.SyclComposer(sfg) sfg = sycl.SyclComposer(sfg)
sfg.namespace("gen")
u_src, u_dst, f = ps.fields("u_src, u_dst, f : double[2D]", layout="fzyx") u_src, u_dst, f = ps.fields("u_src, u_dst, f : double[2D]", layout="fzyx")
h = sp.Symbol("h") h = sp.Symbol("h")
......
...@@ -2,7 +2,6 @@ import sympy as sp ...@@ -2,7 +2,6 @@ import sympy as sp
from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type
from pystencils.types import PsCustomType from pystencils.types import PsCustomType
from pystencilssfg import SfgContext, SfgComposer
from pystencilssfg.composer import make_sequence from pystencilssfg.composer import make_sequence
from pystencilssfg.lang import IFieldExtraction, AugExpr from pystencilssfg.lang import IFieldExtraction, AugExpr
...@@ -11,10 +10,7 @@ from pystencilssfg.ir import SfgStatements, SfgSequence ...@@ -11,10 +10,7 @@ from pystencilssfg.ir import SfgStatements, SfgSequence
from pystencilssfg.ir.postprocessing import CallTreePostProcessing from pystencilssfg.ir.postprocessing import CallTreePostProcessing
def test_live_vars(): def test_live_vars(sfg):
ctx = SfgContext()
sfg = SfgComposer(ctx)
f, g = fields("f, g(2): double[2D]") f, g = fields("f, g(2): double[2D]")
x, y = [TypedSymbol(n, "double") for n in "xy"] x, y = [TypedSymbol(n, "double") for n in "xy"]
z = sp.Symbol("z") z = sp.Symbol("z")
...@@ -42,10 +38,7 @@ def test_live_vars(): ...@@ -42,10 +38,7 @@ def test_live_vars():
assert free_vars == expected assert free_vars == expected
def test_find_sympy_symbols(): def test_find_sympy_symbols(sfg):
ctx = SfgContext()
sfg = SfgComposer(ctx)
f, g = fields("f, g(2): double[2D]") f, g = fields("f, g(2): double[2D]")
x, y, z = sp.symbols("x, y, z") x, y, z = sp.symbols("x, y, z")
...@@ -94,7 +87,7 @@ class DemoFieldExtraction(IFieldExtraction): ...@@ -94,7 +87,7 @@ class DemoFieldExtraction(IFieldExtraction):
return AugExpr.format("{}.stride({})", self.obj, coordinate) return AugExpr.format("{}.stride({})", self.obj, coordinate)
def test_field_extraction(): def test_field_extraction(sfg):
sx, sy, tx, ty = [ sx, sy, tx, ty = [
TypedSymbol(n, create_type("int64")) for n in ("sx", "sy", "tx", "ty") TypedSymbol(n, create_type("int64")) for n in ("sx", "sy", "tx", "ty")
] ]
...@@ -104,8 +97,6 @@ def test_field_extraction(): ...@@ -104,8 +97,6 @@ def test_field_extraction():
def set_constant(): def set_constant():
f.center @= 13.2 f.center @= 13.2
sfg = SfgComposer(SfgContext())
khandle = sfg.kernels.create(set_constant) khandle = sfg.kernels.create(set_constant)
extraction = DemoFieldExtraction("f") extraction = DemoFieldExtraction("f")
...@@ -129,7 +120,7 @@ def test_field_extraction(): ...@@ -129,7 +120,7 @@ def test_field_extraction():
assert stmt.code_string == line assert stmt.code_string == line
def test_duplicate_field_shapes(): def test_duplicate_field_shapes(sfg):
N, tx, ty = [TypedSymbol(n, create_type("int64")) for n in ("N", "tx", "ty")] N, tx, ty = [TypedSymbol(n, create_type("int64")) for n in ("N", "tx", "ty")]
f = Field("f", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty)) f = Field("f", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty))
g = Field("g", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty)) g = Field("g", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty))
...@@ -138,8 +129,6 @@ def test_duplicate_field_shapes(): ...@@ -138,8 +129,6 @@ def test_duplicate_field_shapes():
def set_constant(): def set_constant():
f.center @= g.center(0) f.center @= g.center(0)
sfg = SfgComposer(SfgContext())
khandle = sfg.kernels.create(set_constant) khandle = sfg.kernels.create(set_constant)
call_tree = make_sequence( call_tree = make_sequence(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment