From 271404a6da71fa30cd238ab1c59d5668cf9d94e8 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 7 Feb 2025 13:30:43 +0100
Subject: [PATCH] update and fix large parts of the test suite

---
 conftest.py                                   | 33 +++++++++------
 src/pystencilssfg/cli.py                      | 11 +----
 src/pystencilssfg/config.py                   | 20 +++++++++
 src/pystencilssfg/context.py                  |  2 +-
 src/pystencilssfg/emission/file_printer.py    |  9 ++--
 src/pystencilssfg/generator.py                | 41 +++++++++++++++----
 tests/extensions/test_sycl.py                 | 12 +++---
 tests/generator/test_config.py                |  5 ++-
 .../source/BasicDefinitions.py                |  4 +-
 .../generator_scripts/source/Conditionals.py  |  4 +-
 .../generator_scripts/source/JacobiMdspan.py  |  4 +-
 .../source/MdSpanFixedShapeLayouts.py         |  3 +-
 .../source/MdSpanLbStreaming.py               |  3 +-
 tests/generator_scripts/source/ScaleKernel.py |  3 +-
 .../source/StlContainers1D.py                 |  4 +-
 tests/generator_scripts/source/SyclBuffers.py |  3 +-
 tests/ir/test_postprocessing.py               | 19 ++-------
 17 files changed, 104 insertions(+), 76 deletions(-)

diff --git a/conftest.py b/conftest.py
index 661e722..287ed04 100644
--- a/conftest.py
+++ b/conftest.py
@@ -2,21 +2,30 @@ import pytest
 from os import path
 
 
-@pytest.fixture(autouse=True)
-def prepare_doctest_namespace(doctest_namespace):
-    from pystencilssfg import SfgContext, SfgComposer
-    from pystencilssfg import lang
-
-    #   Place a composer object in the environment for doctests
-
-    sfg = SfgComposer(SfgContext())
-    doctest_namespace["sfg"] = sfg
-    doctest_namespace["lang"] = lang
-
-
 DATA_DIR = path.join(path.split(__file__)[0], "tests/data")
 
 
 @pytest.fixture
 def sample_config_module():
     return path.join(DATA_DIR, "project_config.py")
+
+
+@pytest.fixture
+def sfg():
+    from pystencilssfg import SfgContext, SfgComposer
+    from pystencilssfg.ir import SfgSourceFile, SfgSourceFileType
+
+    return SfgComposer(
+        SfgContext(
+            header_file=SfgSourceFile("", SfgSourceFileType.HEADER),
+            impl_file=SfgSourceFile("", SfgSourceFileType.TRANSLATION_UNIT),
+        )
+    )
+
+
+@pytest.fixture(autouse=True)
+def prepare_doctest_namespace(doctest_namespace, sfg):
+    from pystencilssfg import lang
+
+    doctest_namespace["sfg"] = sfg
+    doctest_namespace["lang"] = lang
diff --git a/src/pystencilssfg/cli.py b/src/pystencilssfg/cli.py
index 8cd9f4f..d612fbc 100644
--- a/src/pystencilssfg/cli.py
+++ b/src/pystencilssfg/cli.py
@@ -1,12 +1,11 @@
 import sys
 import os
 from os import path
-from pathlib import Path
 from typing import NoReturn
 
 from argparse import ArgumentParser, BooleanOptionalAction
 
-from .config import CommandLineParameters, SfgConfigException, OutputMode
+from .config import CommandLineParameters, SfgConfigException
 
 
 def add_newline_arg(parser):
@@ -81,13 +80,7 @@ def list_files(args) -> NoReturn:
     _, scriptname = path.split(args.codegen_script)
     basename = path.splitext(scriptname)[0]
 
-    output_dir: Path = config.get_option("output_directory")
-
-    header_ext = config.extensions.get_option("header")
-    output_files = [output_dir / f"{basename}.{header_ext}"]
-    if config.output_mode != OutputMode.HEADER_ONLY:
-        impl_ext = config.extensions.get_option("impl")
-        output_files.append(output_dir / f"{basename}.{impl_ext}")
+    output_files = config._get_output_files(basename)
 
     print(
         args.sep.join(str(of) for of in output_files),
diff --git a/src/pystencilssfg/config.py b/src/pystencilssfg/config.py
index 7bbcfc6..18aaa51 100644
--- a/src/pystencilssfg/config.py
+++ b/src/pystencilssfg/config.py
@@ -173,6 +173,26 @@ class SfgConfig(ConfigBase):
     def _validate_output_directory(self, pth: str | Path) -> Path:
         return Path(pth)
 
+    def _get_output_files(self, basename: str):
+        output_dir: Path = self.get_option("output_directory")
+
+        header_ext = self.extensions.get_option("header")
+        impl_ext = self.extensions.get_option("impl")
+        output_files = [output_dir / f"{basename}.{header_ext}"]
+        output_mode = self.get_option("output_mode")
+
+        if impl_ext is None:
+            match output_mode:
+                case OutputMode.INLINE:
+                    impl_ext = "ipp"
+                case OutputMode.STANDALONE:
+                    impl_ext = "cpp"
+
+        if impl_ext is not None:
+            output_files.append(output_dir / f"{basename}.{impl_ext}")
+
+        return tuple(output_files)
+
 
 class CommandLineParameters:
     @staticmethod
diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py
index f1a3251..032c1e4 100644
--- a/src/pystencilssfg/context.py
+++ b/src/pystencilssfg/context.py
@@ -102,7 +102,7 @@ class SfgCursor:
 
         self._cur_namespace: SfgNamespace = namespace
 
-        self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]]
+        self._loc: dict[SfgSourceFile, list[SfgNamespaceElement]] = dict()
         for f in self._ctx.files:
             if self._cur_namespace is not None:
                 block = SfgNamespaceBlock(self._cur_namespace)
diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py
index ec43468..f92fba4 100644
--- a/src/pystencilssfg/emission/file_printer.py
+++ b/src/pystencilssfg/emission/file_printer.py
@@ -35,12 +35,12 @@ class SfgFilePrinter:
         code = ""
 
         if file.file_type == SfgSourceFileType.HEADER:
-            code += "#pragma once\n"
+            code += "#pragma once\n\n"
 
         if file.prelude:
             comment = "/**\n"
             comment += indent(file.prelude, " * ")
-            comment += "\n */\n\n"
+            comment += " */\n\n"
 
             code += comment
 
@@ -54,6 +54,7 @@ class SfgFilePrinter:
         #   Here begins the actual code
         code += "\n\n".join(self.visit(elem) for elem in file.elements)
         code += "\n"
+
         return code
 
     def visit(
@@ -73,6 +74,8 @@ class SfgFilePrinter:
                 return self.visit_decl(entity, inclass)
             case SfgEntityDef(entity):
                 return self.visit_defin(entity, inclass)
+            case SfgClassBody():
+                return self.visit_defin(elem, inclass)
             case _:
                 assert False, "illegal code element"
 
@@ -173,7 +176,7 @@ class SfgFilePrinter:
         code = ""
         if func.inline:
             code += "inline "
-        code += func.return_type.c_string()
+        code += func.return_type.c_string() + " "
         params_str = ", ".join(
             f"{param.dtype.c_string()} {param.name}" for param in func.parameters
         )
diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py
index b055ad5..1191ebe 100644
--- a/src/pystencilssfg/generator.py
+++ b/src/pystencilssfg/generator.py
@@ -1,10 +1,16 @@
 from pathlib import Path
 
-from .config import SfgConfig, CommandLineParameters, OutputMode, GLOBAL_NAMESPACE
+from .config import (
+    SfgConfig,
+    CommandLineParameters,
+    OutputMode,
+    _GlobalNamespace,
+)
 from .context import SfgContext
 from .composer import SfgComposer
 from .emission import SfgCodeEmitter
 from .exceptions import SfgException
+from .lang import HeaderFile
 
 
 class SourceFileGenerator:
@@ -26,7 +32,10 @@ class SourceFileGenerator:
     """
 
     def __init__(
-        self, sfg_config: SfgConfig | None = None, keep_unknown_argv: bool = False
+        self,
+        sfg_config: SfgConfig | None = None,
+        namespace: str | None = None,
+        keep_unknown_argv: bool = False,
     ):
         if sfg_config and not isinstance(sfg_config, SfgConfig):
             raise TypeError("sfg_config is not an SfgConfiguration.")
@@ -68,13 +77,13 @@ class SourceFileGenerator:
 
         self._output_mode: OutputMode = config.get_option("output_mode")
         self._output_dir: Path = config.get_option("output_directory")
-        self._header_ext: str = config.extensions.get_option("header")
-        self._impl_ext: str = config.extensions.get_option("impl")
+
+        output_files = config._get_output_files(basename)
 
         from .ir import SfgSourceFile, SfgSourceFileType
 
         self._header_file = SfgSourceFile(
-            f"{basename}.{self._header_ext}", SfgSourceFileType.HEADER
+            output_files[0].name, SfgSourceFileType.HEADER
         )
         self._impl_file: SfgSourceFile | None
 
@@ -83,17 +92,29 @@ class SourceFileGenerator:
                 self._impl_file = None
             case OutputMode.STANDALONE:
                 self._impl_file = SfgSourceFile(
-                    f"{basename}.{self._impl_ext}", SfgSourceFileType.TRANSLATION_UNIT
+                    output_files[1].name, SfgSourceFileType.TRANSLATION_UNIT
+                )
+                self._impl_file.includes.append(
+                    HeaderFile.parse(self._header_file.name)
                 )
             case OutputMode.INLINE:
                 self._impl_file = SfgSourceFile(
-                    f"{basename}.{self._impl_ext}", SfgSourceFileType.HEADER
+                    output_files[1].name, SfgSourceFileType.HEADER
                 )
 
+        outer_namespace: str | _GlobalNamespace = config.get_option("outer_namespace")
+        match (outer_namespace, namespace):
+            case [_GlobalNamespace(), None]:
+                namespace = None
+            case [_GlobalNamespace(), nspace] | [nspace, None]:
+                namespace = nspace
+            case [outer, inner]:
+                namespace = f"{outer}::{inner}"
+
         self._context = SfgContext(
             self._header_file,
             self._impl_file,
-            None if config.outer_namespace is GLOBAL_NAMESPACE else config.outer_namespace,  # type: ignore
+            namespace,
             config.codestyle,
             argv=script_args,
             project_info=cli_params.get_project_info(),
@@ -119,6 +140,10 @@ class SourceFileGenerator:
 
     def __exit__(self, exc_type, exc_value, traceback):
         if exc_type is None:
+            if self._output_mode == OutputMode.INLINE:
+                assert self._impl_file is not None
+                self._header_file.elements.append(f'#include "{self._impl_file.name}"')
+
             #    Collect header files for inclusion
             # from .ir import collect_includes
 
diff --git a/tests/extensions/test_sycl.py b/tests/extensions/test_sycl.py
index db99278..0e067c8 100644
--- a/tests/extensions/test_sycl.py
+++ b/tests/extensions/test_sycl.py
@@ -5,8 +5,8 @@ import pystencils as ps
 from pystencilssfg import SfgContext
 
 
-def test_parallel_for_1_kernels():
-    sfg = sycl.SyclComposer(SfgContext())
+def test_parallel_for_1_kernels(sfg):
+    sfg = sycl.SyclComposer(sfg)
     data_type = "double"
     dim = 2
     f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]")
@@ -24,8 +24,8 @@ def test_parallel_for_1_kernels():
     )
 
 
-def test_parallel_for_2_kernels():
-    sfg = sycl.SyclComposer(SfgContext())
+def test_parallel_for_2_kernels(sfg):
+    sfg = sycl.SyclComposer(sfg)
     data_type = "double"
     dim = 2
     f, g, h, i = ps.fields(f"f,g,h,i:{data_type}[{dim}D]")
@@ -43,8 +43,8 @@ def test_parallel_for_2_kernels():
     )
 
 
-def test_parallel_for_2_kernels_fail():
-    sfg = sycl.SyclComposer(SfgContext())
+def test_parallel_for_2_kernels_fail(sfg):
+    sfg = sycl.SyclComposer(sfg)
     data_type = "double"
     dim = 2
     f, g = ps.fields(f"f,g:{data_type}[{dim}D]")
diff --git a/tests/generator/test_config.py b/tests/generator/test_config.py
index 4485dc2..250c158 100644
--- a/tests/generator/test_config.py
+++ b/tests/generator/test_config.py
@@ -1,4 +1,5 @@
 import pytest
+from pathlib import Path
 
 from pystencilssfg.config import (
     SfgConfig,
@@ -86,7 +87,7 @@ def test_from_commandline(sample_config_module):
     cli_args = CommandLineParameters(args)
     cfg = cli_args.get_config()
 
-    assert cfg.output_directory == ".out"
+    assert cfg.output_directory == Path(".out")
     assert cfg.extensions.header == "h++"
     assert cfg.extensions.impl == "c++"
 
@@ -100,7 +101,7 @@ def test_from_commandline(sample_config_module):
     assert cfg.clang_format.code_style == "llvm"
     assert cfg.clang_format.skip is True
     assert (
-        cfg.output_directory == "gen_sources"
+        cfg.output_directory == Path("gen_sources")
     )  # value from config module overridden by commandline
     assert cfg.outer_namespace == "myproject"
     assert cfg.extensions.header == "hpp"
diff --git a/tests/generator_scripts/source/BasicDefinitions.py b/tests/generator_scripts/source/BasicDefinitions.py
index 7cfe352..51ad4d5 100644
--- a/tests/generator_scripts/source/BasicDefinitions.py
+++ b/tests/generator_scripts/source/BasicDefinitions.py
@@ -4,12 +4,10 @@ from pystencilssfg import SourceFileGenerator, SfgConfig
 cfg = SfgConfig()
 cfg.clang_format.skip = True
 
-with SourceFileGenerator(cfg) as sfg:
+with SourceFileGenerator(cfg, namespace="awesome") as sfg:
     sfg.prelude("Expect the unexpected, and you shall never be surprised.")
     sfg.include("<iostream>")
     sfg.include("config.h")
 
-    sfg.namespace("awesome")
-
     sfg.code("#define PI 3.1415")
     sfg.code("using namespace std;")
diff --git a/tests/generator_scripts/source/Conditionals.py b/tests/generator_scripts/source/Conditionals.py
index 9016b73..216f95f 100644
--- a/tests/generator_scripts/source/Conditionals.py
+++ b/tests/generator_scripts/source/Conditionals.py
@@ -1,9 +1,7 @@
 from pystencilssfg import SourceFileGenerator
 from pystencils.types import PsCustomType
 
-with SourceFileGenerator() as sfg:
-    sfg.namespace("gen")
-
+with SourceFileGenerator(namespace="gen") as sfg:
     sfg.include("<iostream>")
     sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };")
 
diff --git a/tests/generator_scripts/source/JacobiMdspan.py b/tests/generator_scripts/source/JacobiMdspan.py
index bbe95ac..b8f1744 100644
--- a/tests/generator_scripts/source/JacobiMdspan.py
+++ b/tests/generator_scripts/source/JacobiMdspan.py
@@ -7,9 +7,7 @@ from pystencilssfg.lang.cpp.std import mdspan
 
 mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>")
 
-with SourceFileGenerator() as sfg:
-    sfg.namespace("gen")
-
+with SourceFileGenerator(namespace="gen") as sfg:
     u_src, u_dst, f = fields("u_src, u_dst, f(1) : double[2D]", layout="fzyx")
     h = sp.Symbol("h")
 
diff --git a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py
index c89fe24..9a66b40 100644
--- a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py
+++ b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py
@@ -5,8 +5,7 @@ from pystencilssfg.lang import strip_ptr_ref
 
 std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>")
 
-with SourceFileGenerator() as sfg:
-    sfg.namespace("gen")
+with SourceFileGenerator(namespace="gen") as sfg:
     sfg.include("<cassert>")
 
     def check_layout(field: ps.Field, mdspan: std.mdspan):
diff --git a/tests/generator_scripts/source/MdSpanLbStreaming.py b/tests/generator_scripts/source/MdSpanLbStreaming.py
index 60049a8..ad8a758 100644
--- a/tests/generator_scripts/source/MdSpanLbStreaming.py
+++ b/tests/generator_scripts/source/MdSpanLbStreaming.py
@@ -43,8 +43,7 @@ def lbm_stream(sfg: SfgComposer, field_layout: str, layout_policy: str):
     )
 
 
-with SourceFileGenerator() as sfg:
-    sfg.namespace("gen")
+with SourceFileGenerator(namespace="gen") as sfg:
     sfg.include("<cassert>")
     sfg.include("<array>")
 
diff --git a/tests/generator_scripts/source/ScaleKernel.py b/tests/generator_scripts/source/ScaleKernel.py
index 8bcc75f..1d76dc7 100644
--- a/tests/generator_scripts/source/ScaleKernel.py
+++ b/tests/generator_scripts/source/ScaleKernel.py
@@ -2,7 +2,7 @@ from pystencils import TypedSymbol, fields, kernel
 
 from pystencilssfg import SourceFileGenerator
 
-with SourceFileGenerator() as sfg:
+with SourceFileGenerator(namespace="gen") as sfg:
     N = 10
     α = TypedSymbol("alpha", "float32")
     src, dst = fields(f"src, dst: float32[{N}]")
@@ -13,7 +13,6 @@ with SourceFileGenerator() as sfg:
 
     khandle = sfg.kernels.create(scale)
 
-    sfg.namespace("gen")
     sfg.code(f"constexpr int N = {N};")
 
     sfg.klass("Scale")(
diff --git a/tests/generator_scripts/source/StlContainers1D.py b/tests/generator_scripts/source/StlContainers1D.py
index 3f6ec2c..260a650 100644
--- a/tests/generator_scripts/source/StlContainers1D.py
+++ b/tests/generator_scripts/source/StlContainers1D.py
@@ -5,9 +5,7 @@ from pystencilssfg import SourceFileGenerator
 from pystencilssfg.lang.cpp import std
 
 
-with SourceFileGenerator() as sfg:
-    sfg.namespace("StlContainers1D::gen")
-
+with SourceFileGenerator(namespace="StlContainers1D::gen") as sfg:
     src, dst = ps.fields("src, dst: double[1D]")
 
     asms = [
diff --git a/tests/generator_scripts/source/SyclBuffers.py b/tests/generator_scripts/source/SyclBuffers.py
index 36234a8..4668b3c 100644
--- a/tests/generator_scripts/source/SyclBuffers.py
+++ b/tests/generator_scripts/source/SyclBuffers.py
@@ -4,9 +4,8 @@ from pystencilssfg import SourceFileGenerator
 import pystencilssfg.extensions.sycl as sycl
 
 
-with SourceFileGenerator() as sfg:
+with SourceFileGenerator(namespace="gen") as sfg:
     sfg = sycl.SyclComposer(sfg)
-    sfg.namespace("gen")
 
     u_src, u_dst, f = ps.fields("u_src, u_dst, f : double[2D]", layout="fzyx")
     h = sp.Symbol("h")
diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py
index 070743a..9d51c8f 100644
--- a/tests/ir/test_postprocessing.py
+++ b/tests/ir/test_postprocessing.py
@@ -2,7 +2,6 @@ import sympy as sp
 from pystencils import fields, kernel, TypedSymbol, Field, FieldType, create_type
 from pystencils.types import PsCustomType
 
-from pystencilssfg import SfgContext, SfgComposer
 from pystencilssfg.composer import make_sequence
 
 from pystencilssfg.lang import IFieldExtraction, AugExpr
@@ -11,10 +10,7 @@ from pystencilssfg.ir import SfgStatements, SfgSequence
 from pystencilssfg.ir.postprocessing import CallTreePostProcessing
 
 
-def test_live_vars():
-    ctx = SfgContext()
-    sfg = SfgComposer(ctx)
-
+def test_live_vars(sfg):
     f, g = fields("f, g(2): double[2D]")
     x, y = [TypedSymbol(n, "double") for n in "xy"]
     z = sp.Symbol("z")
@@ -42,10 +38,7 @@ def test_live_vars():
     assert free_vars == expected
 
 
-def test_find_sympy_symbols():
-    ctx = SfgContext()
-    sfg = SfgComposer(ctx)
-
+def test_find_sympy_symbols(sfg):
     f, g = fields("f, g(2): double[2D]")
     x, y, z = sp.symbols("x, y, z")
 
@@ -94,7 +87,7 @@ class DemoFieldExtraction(IFieldExtraction):
         return AugExpr.format("{}.stride({})", self.obj, coordinate)
 
 
-def test_field_extraction():
+def test_field_extraction(sfg):
     sx, sy, tx, ty = [
         TypedSymbol(n, create_type("int64")) for n in ("sx", "sy", "tx", "ty")
     ]
@@ -104,8 +97,6 @@ def test_field_extraction():
     def set_constant():
         f.center @= 13.2
 
-    sfg = SfgComposer(SfgContext())
-
     khandle = sfg.kernels.create(set_constant)
 
     extraction = DemoFieldExtraction("f")
@@ -129,7 +120,7 @@ def test_field_extraction():
         assert stmt.code_string == line
 
 
-def test_duplicate_field_shapes():
+def test_duplicate_field_shapes(sfg):
     N, tx, ty = [TypedSymbol(n, create_type("int64")) for n in ("N", "tx", "ty")]
     f = Field("f", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty))
     g = Field("g", FieldType.GENERIC, "double", (1, 0), (N, N), (tx, ty))
@@ -138,8 +129,6 @@ def test_duplicate_field_shapes():
     def set_constant():
         f.center @= g.center(0)
 
-    sfg = SfgComposer(SfgContext())
-
     khandle = sfg.kernels.create(set_constant)
 
     call_tree = make_sequence(
-- 
GitLab