diff --git a/conftest.py b/conftest.py index 661e722446f25fca1230c55989f00372bf66802c..287ed04129dbeb00b52fb7016de9861d7f952c11 100644 --- a/conftest.py +++ b/conftest.py @@ -2,21 +2,30 @@ import pytest from os import path -@pytest.fixture(autouse=True) -def prepare_doctest_namespace(doctest_namespace): - from pystencilssfg import SfgContext, SfgComposer - from pystencilssfg import lang - - # Place a composer object in the environment for doctests - - sfg = SfgComposer(SfgContext()) - doctest_namespace["sfg"] = sfg - doctest_namespace["lang"] = lang - - DATA_DIR = path.join(path.split(__file__)[0], "tests/data") @pytest.fixture def sample_config_module(): return path.join(DATA_DIR, "project_config.py") + + +@pytest.fixture +def sfg(): + from pystencilssfg import SfgContext, SfgComposer + from pystencilssfg.ir import SfgSourceFile, SfgSourceFileType + + return SfgComposer( + SfgContext( + header_file=SfgSourceFile("", SfgSourceFileType.HEADER), + impl_file=SfgSourceFile("", SfgSourceFileType.TRANSLATION_UNIT), + ) + ) + + +@pytest.fixture(autouse=True) +def prepare_doctest_namespace(doctest_namespace, sfg): + from pystencilssfg import lang + + doctest_namespace["sfg"] = sfg + doctest_namespace["lang"] = lang diff --git a/src/pystencilssfg/cli.py b/src/pystencilssfg/cli.py index 8cd9f4f552014d3ef4770e63c961bbe8680dcdc9..d612fbc886942f68ea4d24044d1ed5d3ee4cf56c 100644 --- a/src/pystencilssfg/cli.py +++ b/src/pystencilssfg/cli.py @@ -1,12 +1,11 @@ import sys import os from os import path -from pathlib import Path from typing import NoReturn from argparse import ArgumentParser, BooleanOptionalAction -from .config import CommandLineParameters, SfgConfigException, OutputMode +from .config import CommandLineParameters, SfgConfigException def add_newline_arg(parser): @@ -81,13 +80,7 @@ def list_files(args) -> NoReturn: _, scriptname = path.split(args.codegen_script) basename = path.splitext(scriptname)[0] - output_dir: Path = config.get_option("output_directory") - - header_ext = config.extensions.get_option("header") - output_files = [output_dir / f"{basename}.{header_ext}"] - if config.output_mode != OutputMode.HEADER_ONLY: - impl_ext = config.extensions.get_option("impl") - output_files.append(output_dir / f"{basename}.{impl_ext}") + output_files = config._get_output_files(basename) print( args.sep.join(str(of) for of in output_files), diff --git a/src/pystencilssfg/config.py b/src/pystencilssfg/config.py index 7bbcfc60fb8534e49b244c2ce06f4c00d09165d1..18aaa515143547bc03b438c5dda4e0169fc16d93 100644 --- a/src/pystencilssfg/config.py +++ b/src/pystencilssfg/config.py @@ -173,6 +173,26 @@ class SfgConfig(ConfigBase): def _validate_output_directory(self, pth: str | Path) -> Path: return Path(pth) + def _get_output_files(self, basename: str): + output_dir: Path = self.get_option("output_directory") + + header_ext = self.extensions.get_option("header") + impl_ext = self.extensions.get_option("impl") + output_files = [output_dir / f"{basename}.{header_ext}"] + output_mode = self.get_option("output_mode") + + if impl_ext is None: + match output_mode: + case OutputMode.INLINE: + impl_ext = "ipp" + case OutputMode.STANDALONE: + impl_ext = "cpp" + + if impl_ext is not None: + output_files.append(output_dir / f"{basename}.{impl_ext}") + + return tuple(output_files) + class CommandLineParameters: @staticmethod diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index f1a32517f3c775435caf135fcaa0e13fa242ad80..032c1e4ff275abb45be03de24391e38a0a9d4429 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -102,7 +102,7 @@ class SfgCursor: self._cur_namespace: SfgNamespace = namespace - self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] + self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] = dict() for f in self._ctx.files: if self._cur_namespace is not None: block = SfgNamespaceBlock(self._cur_namespace) diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py index ec434689903fe4599bb86da82765e6c1dfe68768..f92fba468be8a440639f1ed245cf6bb99bde12b7 100644 --- a/src/pystencilssfg/emission/file_printer.py +++ b/src/pystencilssfg/emission/file_printer.py @@ -35,12 +35,12 @@ class SfgFilePrinter: code = "" if file.file_type == SfgSourceFileType.HEADER: - code += "#pragma once\n" + code += "#pragma once\n\n" if file.prelude: comment = "/**\n" comment += indent(file.prelude, " * ") - comment += "\n */\n\n" + comment += " */\n\n" code += comment @@ -54,6 +54,7 @@ class SfgFilePrinter: # Here begins the actual code code += "\n\n".join(self.visit(elem) for elem in file.elements) code += "\n" + return code def visit( @@ -73,6 +74,8 @@ class SfgFilePrinter: return self.visit_decl(entity, inclass) case SfgEntityDef(entity): return self.visit_defin(entity, inclass) + case SfgClassBody(): + return self.visit_defin(elem, inclass) case _: assert False, "illegal code element" @@ -173,7 +176,7 @@ class SfgFilePrinter: code = "" if func.inline: code += "inline " - code += func.return_type.c_string() + code += func.return_type.c_string() + " " params_str = ", ".join( f"{param.dtype.c_string()} {param.name}" for param in func.parameters ) diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index b055ad518d6f198d5f8aabd8026353177ae3c70a..1191ebeea9ca3633afe1cf92635c7c55f8abe692 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -1,10 +1,16 @@ from pathlib import Path -from .config import SfgConfig, CommandLineParameters, OutputMode, GLOBAL_NAMESPACE +from .config import ( + SfgConfig, + CommandLineParameters, + OutputMode, + _GlobalNamespace, +) from .context import SfgContext from .composer import SfgComposer from .emission import SfgCodeEmitter from .exceptions import SfgException +from .lang import HeaderFile class SourceFileGenerator: @@ -26,7 +32,10 @@ class SourceFileGenerator: """ def __init__( - self, sfg_config: SfgConfig | None = None, keep_unknown_argv: bool = False + self, + sfg_config: SfgConfig | None = None, + namespace: str | None = None, + keep_unknown_argv: bool = False, ): if sfg_config and not isinstance(sfg_config, SfgConfig): raise TypeError("sfg_config is not an SfgConfiguration.") @@ -68,13 +77,13 @@ class SourceFileGenerator: self._output_mode: OutputMode = config.get_option("output_mode") self._output_dir: Path = config.get_option("output_directory") - self._header_ext: str = config.extensions.get_option("header") - self._impl_ext: str = config.extensions.get_option("impl") + + output_files = config._get_output_files(basename) from .ir import SfgSourceFile, SfgSourceFileType self._header_file = SfgSourceFile( - f"{basename}.{self._header_ext}", SfgSourceFileType.HEADER + output_files[0].name, SfgSourceFileType.HEADER ) self._impl_file: SfgSourceFile | None @@ -83,17 +92,29 @@ class SourceFileGenerator: self._impl_file = None case OutputMode.STANDALONE: self._impl_file = SfgSourceFile( - f"{basename}.{self._impl_ext}", SfgSourceFileType.TRANSLATION_UNIT + output_files[1].name, SfgSourceFileType.TRANSLATION_UNIT + ) + self._impl_file.includes.append( + HeaderFile.parse(self._header_file.name) ) case OutputMode.INLINE: self._impl_file = SfgSourceFile( - f"{basename}.{self._impl_ext}", SfgSourceFileType.HEADER + output_files[1].name, SfgSourceFileType.HEADER ) + outer_namespace: str | _GlobalNamespace = config.get_option("outer_namespace") + match (outer_namespace, namespace): + case [_GlobalNamespace(), None]: + namespace = None + case [_GlobalNamespace(), nspace] | [nspace, None]: + namespace = nspace + case [outer, inner]: + namespace = f"{outer}::{inner}" + self._context = SfgContext( self._header_file, self._impl_file, - None if config.outer_namespace is GLOBAL_NAMESPACE else config.outer_namespace, # type: ignore + namespace, config.codestyle, argv=script_args, project_info=cli_params.get_project_info(), @@ -119,6 +140,10 @@ class SourceFileGenerator: def __exit__(self, exc_type, exc_value, traceback): if exc_type is None: + if self._output_mode == OutputMode.INLINE: + assert self._impl_file is not None + self._header_file.elements.append(f'#include "{self._impl_file.name}"') + # Collect header files for inclusion # from .ir import collect_includes diff --git a/tests/extensions/test_sycl.py b/tests/extensions/test_sycl.py index db99278c3a00a333400a7a18882163676c389d00..0e067c8d593f4a038f2b2cdb2b21c40ab47e8bb2 100644 --- a/tests/extensions/test_sycl.py +++ b/tests/extensions/test_sycl.py @@ -5,8 +5,8 @@ import pystencils as ps from pystencilssfg import SfgContext -def test_parallel_for_1_kernels(): - sfg = sycl.SyclComposer(SfgContext()) +def test_parallel_for_1_kernels(sfg): + sfg = sycl.SyclComposer(sfg) data_type = "double" dim = 2 f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]") @@ -24,8 +24,8 @@ def test_parallel_for_1_kernels(): ) -def test_parallel_for_2_kernels(): - sfg = sycl.SyclComposer(SfgContext()) +def test_parallel_for_2_kernels(sfg): + sfg = sycl.SyclComposer(sfg) data_type = "double" dim = 2 f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]") @@ -43,8 +43,8 @@ def test_parallel_for_2_kernels(): ) -def test_parallel_for_2_kernels_fail(): - sfg = sycl.SyclComposer(SfgContext()) +def test_parallel_for_2_kernels_fail(sfg): + sfg = sycl.SyclComposer(sfg) data_type = "double" dim = 2 f, g = ps.fields(f"f,g:{data_type}[{dim}D]") diff --git a/tests/generator/test_config.py b/tests/generator/test_config.py index 4485dc22e639b185b8b0756ae6d69f92af5e45e8..250c158c633d6f8fc2f11f0d4b3b2cbd13a13128 100644 --- a/tests/generator/test_config.py +++ b/tests/generator/test_config.py @@ -1,4 +1,5 @@ import pytest +from pathlib import Path from pystencilssfg.config import ( SfgConfig, @@ -86,7 +87,7 @@ def test_from_commandline(sample_config_module): cli_args = CommandLineParameters(args) cfg = cli_args.get_config() - assert cfg.output_directory == ".out" + assert cfg.output_directory == Path(".out") assert cfg.extensions.header == "h++" assert cfg.extensions.impl == "c++" @@ -100,7 +101,7 @@ def test_from_commandline(sample_config_module): assert cfg.clang_format.code_style == "llvm" assert cfg.clang_format.skip is True assert ( - cfg.output_directory == "gen_sources" + cfg.output_directory == Path("gen_sources") ) # value from config module overridden by commandline assert cfg.outer_namespace == "myproject" assert cfg.extensions.header == "hpp" diff --git a/tests/generator_scripts/source/BasicDefinitions.py b/tests/generator_scripts/source/BasicDefinitions.py index 7cfe352910b676429b97cb8f3d29bec68b74810a..51ad4d5ce30666dda7d18769a8c4c53b2c943722 100644 --- a/tests/generator_scripts/source/BasicDefinitions.py +++ b/tests/generator_scripts/source/BasicDefinitions.py @@ -4,12 +4,10 @@ from pystencilssfg import SourceFileGenerator, SfgConfig cfg = SfgConfig() cfg.clang_format.skip = True -with SourceFileGenerator(cfg) as sfg: +with SourceFileGenerator(cfg, namespace="awesome") as sfg: sfg.prelude("Expect the unexpected, and you shall never be surprised.") sfg.include("<iostream>") sfg.include("config.h") - sfg.namespace("awesome") - sfg.code("#define PI 3.1415") sfg.code("using namespace std;") diff --git a/tests/generator_scripts/source/Conditionals.py b/tests/generator_scripts/source/Conditionals.py index 9016b73744f78fef504bb4f09b6742a630a6b12d..216f95f0c9bad0fff2d84b6f7e5f33ed92a17a2e 100644 --- a/tests/generator_scripts/source/Conditionals.py +++ b/tests/generator_scripts/source/Conditionals.py @@ -1,9 +1,7 @@ from pystencilssfg import SourceFileGenerator from pystencils.types import PsCustomType -with SourceFileGenerator() as sfg: - sfg.namespace("gen") - +with SourceFileGenerator(namespace="gen") as sfg: sfg.include("<iostream>") sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };") diff --git a/tests/generator_scripts/source/JacobiMdspan.py b/tests/generator_scripts/source/JacobiMdspan.py index bbe95ac272edbf3b4d9711088c91168cdb525d54..b8f1744f2dfebcc844834d09cf3cd2393a94f591 100644 --- a/tests/generator_scripts/source/JacobiMdspan.py +++ b/tests/generator_scripts/source/JacobiMdspan.py @@ -7,9 +7,7 @@ from pystencilssfg.lang.cpp.std import mdspan mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") -with SourceFileGenerator() as sfg: - sfg.namespace("gen") - +with SourceFileGenerator(namespace="gen") as sfg: u_src, u_dst, f = fields("u_src, u_dst, f(1) : double[2D]", layout="fzyx") h = sp.Symbol("h") diff --git a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py index c89fe2455e3ff117596bdd63d538d56d2afcc3b5..9a66b4030701fcde1f873736b68365d1599fafc3 100644 --- a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py +++ b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py @@ -5,8 +5,7 @@ from pystencilssfg.lang import strip_ptr_ref std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") -with SourceFileGenerator() as sfg: - sfg.namespace("gen") +with SourceFileGenerator(namespace="gen") as sfg: sfg.include("<cassert>") def check_layout(field: ps.Field, mdspan: std.mdspan): diff --git a/tests/generator_scripts/source/MdSpanLbStreaming.py b/tests/generator_scripts/source/MdSpanLbStreaming.py index 60049a86c8142ea08741928077477a9c6f6a4e39..ad8a7583f5bf5f384fb3ed4e1cc5c1485c1b5877 100644 --- a/tests/generator_scripts/source/MdSpanLbStreaming.py +++ b/tests/generator_scripts/source/MdSpanLbStreaming.py @@ -43,8 +43,7 @@ def lbm_stream(sfg: SfgComposer, field_layout: str, layout_policy: str): ) -with SourceFileGenerator() as sfg: - sfg.namespace("gen") +with SourceFileGenerator(namespace="gen") as sfg: sfg.include("<cassert>") sfg.include("<array>") diff --git a/tests/generator_scripts/source/ScaleKernel.py b/tests/generator_scripts/source/ScaleKernel.py index 8bcc75fb7c98e8d46602c7f3f888650d8b8e011c..1d76dc7124265cd4204582e32959f10fde8368c7 100644 --- a/tests/generator_scripts/source/ScaleKernel.py +++ b/tests/generator_scripts/source/ScaleKernel.py @@ -2,7 +2,7 @@ from pystencils import TypedSymbol, fields, kernel from pystencilssfg import SourceFileGenerator -with SourceFileGenerator() as sfg: +with SourceFileGenerator(namespace="gen") as sfg: N = 10 α = TypedSymbol("alpha", "float32") src, dst = fields(f"src, dst: float32[{N}]") @@ -13,7 +13,6 @@ with SourceFileGenerator() as sfg: khandle = sfg.kernels.create(scale) - sfg.namespace("gen") sfg.code(f"constexpr int N = {N};") sfg.klass("Scale")( diff --git a/tests/generator_scripts/source/StlContainers1D.py b/tests/generator_scripts/source/StlContainers1D.py index 3f6ec2c953a6537bef9785d837a2def88439972a..260a6504a7ef3c8fb933bf73852968258a274312 100644 --- a/tests/generator_scripts/source/StlContainers1D.py +++ b/tests/generator_scripts/source/StlContainers1D.py @@ -5,9 +5,7 @@ from pystencilssfg import SourceFileGenerator from pystencilssfg.lang.cpp import std -with SourceFileGenerator() as sfg: - sfg.namespace("StlContainers1D::gen") - +with SourceFileGenerator(namespace="StlContainers1D::gen") as sfg: src, dst = ps.fields("src, dst: double[1D]") asms = [ diff --git a/tests/generator_scripts/source/SyclBuffers.py b/tests/generator_scripts/source/SyclBuffers.py index 36234a84286a561ea4c18e4810aa5e49a5416c26..4668b3cf01355faadc79dfc2c83e3e7071312b76 100644 --- a/tests/generator_scripts/source/SyclBuffers.py +++ b/tests/generator_scripts/source/SyclBuffers.py @@ -4,9 +4,8 @@ from pystencilssfg import SourceFileGenerator import pystencilssfg.extensions.sycl as sycl -with SourceFileGenerator() as sfg: +with SourceFileGenerator(namespace="gen") as sfg: sfg = sycl.SyclComposer(sfg) - sfg.namespace("gen") u_src, u_dst, f = ps.fields("u_src, u_dst, f : double[2D]", layout="fzyx") h = sp.Symbol("h") diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 070743ae6ce7e63ce4c825142ed28fb6052d647f..9d51c8fa6ef944a51d9c60219c1460061d6514b1 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -2,7 +2,6 @@ import sympy as sp from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type from pystencils.types import PsCustomType -from pystencilssfg import SfgContext, SfgComposer from pystencilssfg.composer import make_sequence from pystencilssfg.lang import IFieldExtraction, AugExpr @@ -11,10 +10,7 @@ from pystencilssfg.ir import SfgStatements, SfgSequence from pystencilssfg.ir.postprocessing import CallTreePostProcessing -def test_live_vars(): - ctx = SfgContext() - sfg = SfgComposer(ctx) - +def test_live_vars(sfg): f, g = fields("f, g(2): double[2D]") x, y = [TypedSymbol(n, "double") for n in "xy"] z = sp.Symbol("z") @@ -42,10 +38,7 @@ def test_live_vars(): assert free_vars == expected -def test_find_sympy_symbols(): - ctx = SfgContext() - sfg = SfgComposer(ctx) - +def test_find_sympy_symbols(sfg): f, g = fields("f, g(2): double[2D]") x, y, z = sp.symbols("x, y, z") @@ -94,7 +87,7 @@ class DemoFieldExtraction(IFieldExtraction): return AugExpr.format("{}.stride({})", self.obj, coordinate) -def test_field_extraction(): +def test_field_extraction(sfg): sx, sy, tx, ty = [ TypedSymbol(n, create_type("int64")) for n in ("sx", "sy", "tx", "ty") ] @@ -104,8 +97,6 @@ def test_field_extraction(): def set_constant(): f.center @= 13.2 - sfg = SfgComposer(SfgContext()) - khandle = sfg.kernels.create(set_constant) extraction = DemoFieldExtraction("f") @@ -129,7 +120,7 @@ def test_field_extraction(): assert stmt.code_string == line -def test_duplicate_field_shapes(): +def test_duplicate_field_shapes(sfg): N, tx, ty = [TypedSymbol(n, create_type("int64")) for n in ("N", "tx", "ty")] f = Field("f", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty)) g = Field("g", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty)) @@ -138,8 +129,6 @@ def test_duplicate_field_shapes(): def set_constant(): f.center @= g.center(0) - sfg = SfgComposer(SfgContext()) - khandle = sfg.kernels.create(set_constant) call_tree = make_sequence(