diff --git a/src/pystencilssfg/ir/analysis.py b/src/pystencilssfg/ir/analysis.py index 5b6d2d693af8c5eb27c4f2c8fc7b7f38a69108b7..4e43eb92c7a2457c5e60f0743c9ac5d3809f87bc 100644 --- a/src/pystencilssfg/ir/analysis.py +++ b/src/pystencilssfg/ir/analysis.py @@ -1,7 +1,5 @@ from __future__ import annotations -from functools import reduce - from ..lang import HeaderFile, includes from .syntax import ( SfgSourceFile, @@ -35,9 +33,7 @@ def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: | SfgMethod(_, _, parameters) | SfgConstructor(_, parameters, _, _) ): - incls: set[HeaderFile] = reduce( - lambda accu, p: accu | includes(p), parameters, set() - ) + incls: set[HeaderFile] = set().union(*(includes(p) for p in parameters)) if isinstance(entity, (SfgFunction, SfgMethod)): incls |= includes(entity.return_type) return incls @@ -61,10 +57,8 @@ def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: return set() case SfgCallTreeNode(): - return reduce( - lambda accu, child: accu | walk_syntax(child), - obj.children, - obj.required_includes, + return obj.required_includes.union( + *(walk_syntax(child) for child in obj.children), ) case SfgEntityDecl(entity): @@ -92,16 +86,12 @@ def collect_includes(file: SfgSourceFile) -> set[HeaderFile]: assert False, "unexpected entity" case SfgNamespaceBlock(_, elements) | SfgVisibilityBlock(_, elements): - return reduce( - lambda accu, elem: accu | walk_syntax(elem), elements, set() - ) # type: ignore + return set().union(*(walk_syntax(elem) for elem in elements)) # type: ignore case SfgClassBody(_, vblocks): - return reduce( - lambda accu, vblock: accu | walk_syntax(vblock), vblocks, set() - ) + return set().union(*(walk_syntax(vb) for vb in vblocks)) case _: assert False, "unexpected syntax element" - return reduce(lambda accu, elem: accu | walk_syntax(elem), file.elements, set()) + return set().union(*(walk_syntax(elem) for elem in file.elements)) diff --git a/tests/extensions/test_sycl.py b/tests/extensions/test_sycl.py index 0e067c8d593f4a038f2b2cdb2b21c40ab47e8bb2..71effb60a3f10a3d5f505b01b1ef256cea9ad45e 100644 --- a/tests/extensions/test_sycl.py +++ b/tests/extensions/test_sycl.py @@ -1,8 +1,6 @@ import pytest -from pystencilssfg import SourceFileGenerator import pystencilssfg.extensions.sycl as sycl import pystencils as ps -from pystencilssfg import SfgContext def test_parallel_for_1_kernels(sfg): diff --git a/tests/generator_scripts/index.yaml b/tests/generator_scripts/index.yaml index b7237999332beef0de277f35e01e0285ff6f1389..0e08e228702bf5c762120e9bc66117a5892bf66d 100644 --- a/tests/generator_scripts/index.yaml +++ b/tests/generator_scripts/index.yaml @@ -17,6 +17,17 @@ TestIllegalArgs: extra-args: [--sfg-file-extensionss, ".c++,.h++"] expect-failure: true +TestIncludeSorting: + sfg-args: + output-mode: header-only + expect-code: + hpp: + - regex: >- + #include\s\<memory>\s* + #include\s<vector>\s* + #include\s<array> + strip-whitespace: true + # Basic Composer Functionality BasicDefinitions: diff --git a/tests/generator_scripts/source/TestIncludeSorting.py b/tests/generator_scripts/source/TestIncludeSorting.py new file mode 100644 index 0000000000000000000000000000000000000000..8a584f6b0f4a836b50d862b2d3a161ce98caa517 --- /dev/null +++ b/tests/generator_scripts/source/TestIncludeSorting.py @@ -0,0 +1,23 @@ +from pystencilssfg import SourceFileGenerator, SfgConfig +from pystencilssfg.lang import HeaderFile + + +def sortkey(h: HeaderFile): + try: + return [ + "memory", + "vector", + "array" + ].index(h.filepath) + except ValueError: + return 100 + + +cfg = SfgConfig() +cfg.codestyle.includes_sorting_key = sortkey + + +with SourceFileGenerator(cfg) as sfg: + sfg.include("<array>") + sfg.include("<memory>") + sfg.include("<vector>")