diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 08422bebbb685bad1220be9412a7ecff4491f8c9..7c08b6765a8b3ca76839d55aee2c7018e68bb939 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -204,8 +204,7 @@ class SfgBasicComposer(SfgIComposer): self.code(*definitions) def namespace(self, namespace: str): - # TODO: Enter into a new namespace context - raise NotImplementedError() + return self._cursor.enter_namespace(namespace) def generate(self, generator: CustomGenerator): """Invoke a custom code generator with the underlying context.""" diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index 24c38e1f26e50f4569da919ebfb272ae8251baac..199c678ba28e449f42b29c56147b9a4fd0d523bb 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import Sequence, Any, Generator +from contextlib import contextmanager from .config import CodeStyle from .ir import ( @@ -38,12 +39,13 @@ class SfgContext: self._global_namespace = SfgGlobalNamespace() - current_ns: SfgNamespace = self._global_namespace + current_namespace: SfgNamespace if namespace is not None: - for token in namespace.split("::"): - current_ns = SfgNamespace(token, current_ns) + current_namespace = self._global_namespace.get_child_namespace(namespace) + else: + current_namespace = self._global_namespace - self._cursor = SfgCursor(self, current_ns) + self._cursor = SfgCursor(self, current_namespace) @property def argv(self) -> Sequence[str]: @@ -113,8 +115,6 @@ class SfgCursor: else: self._loc[f] = f.elements - # TODO: Enter and exit namespace blocks - @property def current_namespace(self) -> SfgNamespace: return self._cur_namespace @@ -135,3 +135,23 @@ class SfgCursor: f"Cannot write element {elem} to implemenation file since no implementation file is being generated." ) self._loc[impl_file].append(elem) + + def enter_namespace(self, qual_name: str): + namespace = self._cur_namespace.get_child_namespace(qual_name) + + outer_locs = self._loc.copy() + + for f in self._ctx.files: + block = SfgNamespaceBlock(namespace, qual_name) + self._loc[f].append(block) + self._loc[f] = block.elements + + @contextmanager + def ctxmgr(): + try: + yield None + finally: + # Have the cursor step back out of the nested namespace blocks + self._loc = outer_locs + + return ctxmgr() diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index 471e60bc580a00ed21b219e7f91a9a2be383cfd0..f3f67a02f4da7a44ae324ddd41f5585645cadb58 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -35,7 +35,6 @@ class SourceFileGenerator: def __init__( self, sfg_config: SfgConfig | None = None, - namespace: str | None = None, keep_unknown_argv: bool = False, ): if sfg_config and not isinstance(sfg_config, SfgConfig): @@ -108,15 +107,11 @@ class SourceFileGenerator: outer_namespace: str | _GlobalNamespace = config.get_option("outer_namespace") - match (outer_namespace, namespace): - case [_GlobalNamespace(), None]: - namespace = None - case [_GlobalNamespace(), nspace] if nspace is not None: - namespace = nspace - case [nspace, None] if not isinstance(nspace, _GlobalNamespace): - namespace = nspace - case [outer, inner]: - namespace = f"{outer}::{inner}" + namespace: str | None + if isinstance(outer_namespace, _GlobalNamespace): + namespace = None + else: + namespace = outer_namespace self._context = SfgContext( self._header_file, diff --git a/src/pystencilssfg/ir/entities.py b/src/pystencilssfg/ir/entities.py index 6e6597dbab1e5e5b6de3ea7853091fa37655d6ad..62ae1eb7611065c7b1a12b77d6894143aed341dd 100644 --- a/src/pystencilssfg/ir/entities.py +++ b/src/pystencilssfg/ir/entities.py @@ -74,8 +74,29 @@ class SfgNamespace(SfgCodeEntity): self._entities: dict[str, SfgCodeEntity] = dict() - def get_entity(self, name: str) -> SfgCodeEntity | None: - return self._entities.get(name, None) + def get_entity(self, qual_name: str) -> SfgCodeEntity | None: + """Find an entity with the given qualified name within this namespace. + + If `qual_name` contains any qualifying delimiters ``::``, + each component but the last is interpreted as a namespace. + """ + tokens = qual_name.split("::", 1) + match tokens: + case [entity_name]: + return self._entities.get(entity_name, None) + case [nspace, remaining_qualname]: + sub_nspace = self._entities.get(nspace, None) + if sub_nspace is not None: + if not isinstance(sub_nspace, SfgNamespace): + raise KeyError( + f"Unable to find entity {qual_name} in namespace {self._name}: " + f"Entity {nspace} is not a namespace." + ) + return sub_nspace.get_entity(remaining_qualname) + else: + return None + case _: + assert False, "unreachable code" def add_entity(self, entity: SfgCodeEntity): if entity.name in self._entities: @@ -84,6 +105,24 @@ class SfgNamespace(SfgCodeEntity): ) self._entities[entity.name] = entity + def get_child_namespace(self, qual_name: str): + if not qual_name: + raise ValueError("Anonymous namespaces are not supported") + + # Find the namespace by qualified lookup ... + namespace = self.get_entity(qual_name) + if namespace is not None: + if not type(namespace) is SfgNamespace: + raise ValueError(f"Entity {qual_name} exists, but is not a namespace") + else: + # ... or create it + tokens = qual_name.split("::") + namespace = self + for tok in tokens: + namespace = SfgNamespace(tok, namespace) + + return namespace + class SfgGlobalNamespace(SfgNamespace): """The C++ global namespace.""" diff --git a/tests/generator_scripts/source/BasicDefinitions.py b/tests/generator_scripts/source/BasicDefinitions.py index 51ad4d5ce30666dda7d18769a8c4c53b2c943722..4453066583348e6fad77ae89cd071fccfc5ce19d 100644 --- a/tests/generator_scripts/source/BasicDefinitions.py +++ b/tests/generator_scripts/source/BasicDefinitions.py @@ -4,7 +4,9 @@ from pystencilssfg import SourceFileGenerator, SfgConfig cfg = SfgConfig() cfg.clang_format.skip = True -with SourceFileGenerator(cfg, namespace="awesome") as sfg: +with SourceFileGenerator(cfg) as sfg: + sfg.namespace("awesome") + sfg.prelude("Expect the unexpected, and you shall never be surprised.") sfg.include("<iostream>") sfg.include("config.h") diff --git a/tests/generator_scripts/source/Conditionals.py b/tests/generator_scripts/source/Conditionals.py index 216f95f0c9bad0fff2d84b6f7e5f33ed92a17a2e..9016b73744f78fef504bb4f09b6742a630a6b12d 100644 --- a/tests/generator_scripts/source/Conditionals.py +++ b/tests/generator_scripts/source/Conditionals.py @@ -1,7 +1,9 @@ from pystencilssfg import SourceFileGenerator from pystencils.types import PsCustomType -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") + sfg.include("<iostream>") sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };") diff --git a/tests/generator_scripts/source/JacobiMdspan.py b/tests/generator_scripts/source/JacobiMdspan.py index b8f1744f2dfebcc844834d09cf3cd2393a94f591..2e0741a046d317ed41d091524c0ab6b855318f46 100644 --- a/tests/generator_scripts/source/JacobiMdspan.py +++ b/tests/generator_scripts/source/JacobiMdspan.py @@ -7,13 +7,17 @@ from pystencilssfg.lang.cpp.std import mdspan mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") + u_src, u_dst, f = fields("u_src, u_dst, f(1) : double[2D]", layout="fzyx") h = sp.Symbol("h") @kernel def poisson_jacobi(): - u_dst[0,0] @= (h**2 * f[0, 0] + u_src[1, 0] + u_src[-1, 0] + u_src[0, 1] + u_src[0, -1]) / 4 + u_dst[0, 0] @= ( + h**2 * f[0, 0] + u_src[1, 0] + u_src[-1, 0] + u_src[0, 1] + u_src[0, -1] + ) / 4 poisson_kernel = sfg.kernels.create(poisson_jacobi) @@ -21,5 +25,5 @@ with SourceFileGenerator(namespace="gen") as sfg: sfg.map_field(u_src, mdspan.from_field(u_src, layout_policy="layout_left")), sfg.map_field(u_dst, mdspan.from_field(u_dst, layout_policy="layout_left")), sfg.map_field(f, mdspan.from_field(f, layout_policy="layout_left")), - sfg.call(poisson_kernel) + sfg.call(poisson_kernel), ) diff --git a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py index 9a66b4030701fcde1f873736b68365d1599fafc3..c89fe2455e3ff117596bdd63d538d56d2afcc3b5 100644 --- a/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py +++ b/tests/generator_scripts/source/MdSpanFixedShapeLayouts.py @@ -5,7 +5,8 @@ from pystencilssfg.lang import strip_ptr_ref std.mdspan.configure(namespace="std::experimental", header="<experimental/mdspan>") -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") sfg.include("<cassert>") def check_layout(field: ps.Field, mdspan: std.mdspan): diff --git a/tests/generator_scripts/source/MdSpanLbStreaming.py b/tests/generator_scripts/source/MdSpanLbStreaming.py index ad8a7583f5bf5f384fb3ed4e1cc5c1485c1b5877..60049a86c8142ea08741928077477a9c6f6a4e39 100644 --- a/tests/generator_scripts/source/MdSpanLbStreaming.py +++ b/tests/generator_scripts/source/MdSpanLbStreaming.py @@ -43,7 +43,8 @@ def lbm_stream(sfg: SfgComposer, field_layout: str, layout_policy: str): ) -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") sfg.include("<cassert>") sfg.include("<array>") diff --git a/tests/generator_scripts/source/ScaleKernel.py b/tests/generator_scripts/source/ScaleKernel.py index 1d76dc7124265cd4204582e32959f10fde8368c7..2242a3bc34f8edf0fd6ecff6a8c1bd14acf4b0fd 100644 --- a/tests/generator_scripts/source/ScaleKernel.py +++ b/tests/generator_scripts/source/ScaleKernel.py @@ -2,7 +2,9 @@ from pystencils import TypedSymbol, fields, kernel from pystencilssfg import SourceFileGenerator -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") + N = 10 α = TypedSymbol("alpha", "float32") src, dst = fields(f"src, dst: float32[{N}]") diff --git a/tests/generator_scripts/source/StlContainers1D.py b/tests/generator_scripts/source/StlContainers1D.py index 260a6504a7ef3c8fb933bf73852968258a274312..91b29110b2aeff7c713c2ea8e89482ecca9bb388 100644 --- a/tests/generator_scripts/source/StlContainers1D.py +++ b/tests/generator_scripts/source/StlContainers1D.py @@ -5,24 +5,23 @@ from pystencilssfg import SourceFileGenerator from pystencilssfg.lang.cpp import std -with SourceFileGenerator(namespace="StlContainers1D::gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("StlContainers1D::gen") + src, dst = ps.fields("src, dst: double[1D]") - asms = [ - ps.Assignment(dst[0], sp.Rational(1, 3) * (src[-1] + src[0] + src[1])) - ] + asms = [ps.Assignment(dst[0], sp.Rational(1, 3) * (src[-1] + src[0] + src[1]))] kernel = sfg.kernels.create(asms, "average") sfg.function("averageVector")( sfg.map_field(src, std.vector.from_field(src)), sfg.map_field(dst, std.vector.from_field(dst)), - sfg.call(kernel) + sfg.call(kernel), ) sfg.function("averageSpan")( sfg.map_field(src, std.span.from_field(src)), sfg.map_field(dst, std.span.from_field(dst)), - sfg.call(kernel) + sfg.call(kernel), ) - diff --git a/tests/generator_scripts/source/SyclBuffers.py b/tests/generator_scripts/source/SyclBuffers.py index 4668b3cf01355faadc79dfc2c83e3e7071312b76..36234a84286a561ea4c18e4810aa5e49a5416c26 100644 --- a/tests/generator_scripts/source/SyclBuffers.py +++ b/tests/generator_scripts/source/SyclBuffers.py @@ -4,8 +4,9 @@ from pystencilssfg import SourceFileGenerator import pystencilssfg.extensions.sycl as sycl -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: sfg = sycl.SyclComposer(sfg) + sfg.namespace("gen") u_src, u_dst, f = ps.fields("u_src, u_dst, f : double[2D]", layout="fzyx") h = sp.Symbol("h") diff --git a/tests/integration/cmake_project/GenTest.py b/tests/integration/cmake_project/GenTest.py index 093374c997f10d0681bb47c2aa909c8c85d65e73..81aec18e250e33cb1969565d80d09ee31206e69e 100644 --- a/tests/integration/cmake_project/GenTest.py +++ b/tests/integration/cmake_project/GenTest.py @@ -1,6 +1,7 @@ from pystencilssfg import SourceFileGenerator -with SourceFileGenerator(namespace="gen") as sfg: +with SourceFileGenerator() as sfg: + sfg.namespace("gen") retval = 42 if sfg.context.project_info is None else sfg.context.project_info sfg.function("getValue", return_type="int")(