Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (6)
Showing
with 278 additions and 197 deletions
......@@ -9,6 +9,7 @@
# dev environment
**/.venv
**/venv
# build artifacts
dist
......@@ -17,9 +18,6 @@ dist
*.egg-info
# tests and coverage
.coverage
.coverage*
htmlcov
coverage.xml
# mkdocs
site
\ No newline at end of file
......@@ -37,9 +37,9 @@ testsuite:
- docker
before_script:
- pip install "git+https://i10git.cs.fau.de/pycodegen/pystencils.git@v2.0-dev"
- pip install -e .
- pip install -e .[tests]
script:
- pytest -v --cov=src/pystencilssfg --cov-report=term
- pytest -v --cov=src/pystencilssfg --cov-report=term --cov-config=pyproject.toml
- coverage html
- coverage xml
coverage: '/TOTAL.*\s+(\d+%)$/'
......
***************************************
Composer API (`pystencilssfg.composer`)
***************************************
*****************************************
Composer API (``pystencilssfg.composer``)
*****************************************
.. autoclass:: pystencilssfg.composer.SfgComposer
:members:
......
......@@ -9,8 +9,20 @@ Expressions
.. automodule:: pystencilssfg.lang.expressions
:members:
C++ Standard Library (`pystencilssfg.lang.cpp`)
-----------------------------------------------
Header Files
------------
.. automodule:: pystencilssfg.lang.headers
:members:
Data Types
----------
.. automodule:: pystencilssfg.lang.types
:members:
C++ Standard Library (``pystencilssfg.lang.cpp``)
-------------------------------------------------
Quick Access
^^^^^^^^^^^^
......
......@@ -65,12 +65,15 @@ html_theme_options = {
intersphinx_mapping = {
"python": ("https://docs.python.org/3.8", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"matplotlib": ("https://matplotlib.org/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"sympy": ("https://docs.sympy.org/latest/", None),
"pystencils": ("https://da15siwa.pages.i10git.cs.fau.de/dev-docs/pystencils-nbackend/", None),
}
# References
# Treat `single-quoted` code blocks as references to any
default_role = "any"
# Autodoc options
......
......@@ -68,7 +68,7 @@ import sympy as sp
from pystencils import fields, kernel
from pystencilssfg import SourceFileGenerator
from pystencilssfg.lang.cpp import mdspan_ref
from pystencilssfg.lang.cpp import std
with SourceFileGenerator() as sfg:
u_src, u_dst, f = fields("u_src, u_dst, f(1) : double[2D]", layout="fzyx")
......@@ -81,9 +81,9 @@ with SourceFileGenerator() as sfg:
poisson_kernel = sfg.kernels.create(poisson_jacobi)
sfg.function("jacobi_smooth")(
sfg.map_field(u_src, mdspan_ref(u_src)),
sfg.map_field(u_dst, mdspan_ref(u_dst)),
sfg.map_field(f, mdspan_ref(f)),
sfg.map_field(u_src, std.mdspan.from_field(u_src)),
sfg.map_field(u_dst, std.mdspan.from_field(u_dst)),
sfg.map_field(f, std.mdspan.from_field(f)),
sfg.call(poisson_kernel)
)
```
......@@ -102,11 +102,11 @@ python poisson_smoother.py
```
During execution, *pystencils-sfg* assembles the above constructs into an internal representation of the C++ files.
It then takes the name of your Python script, replaces `.py` with `.cpp` and `.h`,
It then takes the name of your Python script, replaces `.py` with `.cpp` and `.hpp`,
and exports the constructed code to the files
`poisson_smoother.cpp` and `poisson_smoother.h` into the current directory, ready to be `#include`d.
`poisson_smoother.cpp` and `poisson_smoother.hpp` into the current directory, ready to be `#include`d.
````{dropdown} poisson_smoother.h
````{dropdown} poisson_smoother.hpp
```C++
#pragma once
......@@ -129,7 +129,7 @@ void jacobi_smooth(
````{dropdown} poisson_smoother.cpp
```C++
#include "poisson_smoother.h"
#include "poisson_smoother.hpp"
#include <math.h>
......
......@@ -21,8 +21,8 @@ with SourceFileGenerator() as sfg:
sfg.include("<span>")
sfg.function("scale_kernel")(
sfg.map_field(src, std.vector(src)),
sfg.map_field(dst, std.span(dst)),
sfg.map_field(src, std.vector.from_field(src)),
sfg.map_field(dst, std.span.from_field(dst)),
sfg.call(scale_kernel)
)
# end
from pystencils import Target, CreateKernelConfig, no_jit
from lbmpy import create_lb_update_rule, LBMOptimisation
from pystencilssfg import SourceFileGenerator, SfgConfig, OutputMode
from pystencilssfg.lang.cpp.sycl_accessor import sycl_accessor_ref
import pystencilssfg.extensions.sycl as sycl
from itertools import chain
sfg_config = SfgConfig(
output_directory="out/test_sycl_buffer",
outer_namespace="gen_code",
output_mode=OutputMode.INLINE,
)
with SourceFileGenerator(sfg_config) as sfg:
sfg = sycl.SyclComposer(sfg)
gen_config = CreateKernelConfig(target=Target.SYCL, jit=no_jit)
opt = LBMOptimisation(field_layout="fzyx")
update = create_lb_update_rule(lbm_optimisation=opt)
kernel = sfg.kernels.create(update, "lbm_update", gen_config)
cgh = sfg.sycl_handler("handler")
rang = sfg.sycl_range(update.method.dim, "range")
mappings = [
sfg.map_field(field, sycl_accessor_ref(field))
for field in chain(update.free_fields, update.bound_fields)
]
sfg.function("lb_update")(
cgh.parallel_for(rang)(
*mappings,
sfg.call(kernel),
),
)
......@@ -23,10 +23,13 @@ requires = [
build-backend = "setuptools.build_meta"
[project.optional-dependencies]
testing = [
tests = [
"flake8>=6.1.0",
"mypy>=1.7.0",
"black"
"black",
"pyyaml",
"requests",
"fasteners",
]
docs = [
"sphinx",
......@@ -54,3 +57,10 @@ omit = [
"src/pystencilssfg/_version.py",
"integration/*"
]
[tool.coverage.report]
exclude_also = [
"\\.\\.\\.",
"if TYPE_CHECKING:",
"@(abc\\.)?abstractmethod",
]
......@@ -5,8 +5,9 @@ python_files = "test_*.py"
# during test collection
addopts =
--doctest-modules
--ignore=tests/generator_scripts/scripts
--ignore=tests/generator_scripts/config
--ignore=tests/generator_scripts/source
--ignore=tests/generator_scripts/deps
--ignore=tests/generator_scripts/expected
--ignore=tests/data
doctest_optionflags = NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL
......@@ -30,7 +30,6 @@ from ..ir import (
SfgSwitch,
)
from ..ir.postprocessing import (
SfgDeferredParamMapping,
SfgDeferredParamSetter,
SfgDeferredFieldMapping,
SfgDeferredVectorMapping,
......@@ -52,11 +51,14 @@ from ..lang import (
_ExprLike,
asvar,
depends,
HeaderFile,
includes,
SfgVar,
AugExpr,
SrcField,
IFieldExtraction,
SrcVector,
void,
)
from ..exceptions import SfgException
......@@ -217,7 +219,7 @@ class SfgBasicComposer(SfgIComposer):
#include <vector>
#include "custom.h"
"""
self._ctx.add_include(SfgHeaderInclude.parse(header_file, private))
self._ctx.add_include(SfgHeaderInclude(HeaderFile.parse(header_file), private))
def numpy_struct(
self, name: str, dtype: np.dtype, add_constructor: bool = True
......@@ -256,7 +258,7 @@ class SfgBasicComposer(SfgIComposer):
func = SfgFunction(name, tree)
self._ctx.add_function(func)
def function(self, name: str):
def function(self, name: str, return_type: UserTypeSpec = void):
"""Add a function.
The syntax of this function adder uses a chain of two calls to mimic C++ syntax:
......@@ -274,7 +276,7 @@ class SfgBasicComposer(SfgIComposer):
def sequencer(*args: SequencerArg):
tree = make_sequence(*args)
func = SfgFunction(name, tree)
func = SfgFunction(name, tree, return_type=create_type(return_type))
self._ctx.add_function(func)
return sequencer
......@@ -284,7 +286,7 @@ class SfgBasicComposer(SfgIComposer):
When using `call`, the given kernel will simply be called as a function.
To invoke a GPU kernel on a specified launch grid, use `cuda_invoke`
or the interfaces of `pystencilssfg.extensions.sycl` instead.
or the interfaces of ``pystencilssfg.extensions.sycl`` instead.
Args:
kernel_handle: Handle to a kernel previously added to some kernel namespace.
......@@ -298,6 +300,7 @@ class SfgBasicComposer(SfgIComposer):
threads_per_block: ExprLike,
stream: ExprLike | None,
):
"""Dispatch a CUDA kernel to the device."""
num_blocks_str = str(num_blocks)
tpb_str = str(threads_per_block)
stream_str = str(stream) if stream is not None else None
......@@ -318,11 +321,9 @@ class SfgBasicComposer(SfgIComposer):
"""Use inside a function body to add parameters to the function."""
return SfgFunctionParams([x.as_variable() for x in args])
def require(self, *includes: str | SfgHeaderInclude) -> SfgRequireIncludes:
def require(self, *incls: str | HeaderFile) -> SfgRequireIncludes:
"""Use inside a function body to require the inclusion of headers."""
return SfgRequireIncludes(
list(SfgHeaderInclude.parse(incl) for incl in includes)
)
return SfgRequireIncludes((HeaderFile.parse(incl) for incl in incls))
def cpptype(
self,
......@@ -385,10 +386,12 @@ class SfgBasicComposer(SfgIComposer):
def parse_args(*args: ExprLike):
args_str = ", ".join(str(arg) for arg in args)
deps: set[SfgVar] = reduce(set.union, (depends(arg) for arg in args), set())
incls: set[HeaderFile] = reduce(set.union, (includes(arg) for arg in args))
return SfgStatements(
f"{lhs_var.dtype.c_string()} {lhs_var.name} {{ {args_str} }};",
(lhs_var,),
deps,
incls,
)
return parse_args
......@@ -443,9 +446,14 @@ class SfgBasicComposer(SfgIComposer):
"""
return SfgBranchBuilder()
def switch(self, switch_arg: ExprLike) -> SfgSwitchBuilder:
"""Use inside a function to construct a switch-case statement."""
return SfgSwitchBuilder(switch_arg)
def switch(self, switch_arg: ExprLike, autobreak: bool = True) -> SfgSwitchBuilder:
"""Use inside a function to construct a switch-case statement.
Args:
switch_arg: Argument to the `switch()` statement
autobreak: Whether to automatically print a `break;` at the end of each case block
"""
return SfgSwitchBuilder(switch_arg, autobreak=autobreak)
def map_field(
self,
......@@ -466,30 +474,14 @@ class SfgBasicComposer(SfgIComposer):
)
def set_param(self, param: VarLike | sp.Symbol, expr: ExprLike):
deps = depends(expr)
var: SfgVar | sp.Symbol = asvar(param) if isinstance(param, _VarLike) else param
return SfgDeferredParamSetter(var, deps, str(expr))
def map_param(
self,
param: VarLike | sp.Symbol,
depends: VarLike | Sequence[VarLike],
mapping: str,
):
from warnings import warn
warn(
"The `map_param` method of `SfgBasicComposer` is deprecated and will be removed "
"in a future version. Use `sfg.set_param` instead.",
FutureWarning,
)
"""Set a kernel parameter to an expression.
if isinstance(depends, _VarLike):
depends = [depends]
lhs_var: SfgVar | sp.Symbol = (
asvar(param) if isinstance(param, _VarLike) else param
)
return SfgDeferredParamMapping(lhs_var, set(asvar(v) for v in depends), mapping)
Code setting the parameter will only be generated if the parameter
is actually alive (i.e. required by some kernel, and not yet set) at
the point this method is called.
"""
var: SfgVar | sp.Symbol = asvar(param) if isinstance(param, _VarLike) else param
return SfgDeferredParamSetter(var, expr)
def map_vector(self, lhs_components: Sequence[VarLike | sp.Symbol], rhs: SrcVector):
"""Extracts scalar numerical values from a vector data type.
......@@ -505,7 +497,7 @@ class SfgBasicComposer(SfgIComposer):
def make_statements(arg: ExprLike) -> SfgStatements:
return SfgStatements(str(arg), (), depends(arg))
return SfgStatements(str(arg), (), depends(arg), includes(arg))
def make_sequence(*args: SequencerArg) -> SfgSequence:
......@@ -613,16 +605,19 @@ class SfgBranchBuilder(SfgNodeBuilder):
class SfgSwitchBuilder(SfgNodeBuilder):
"""Builder for C++ switches."""
def __init__(self, switch_arg: ExprLike):
def __init__(self, switch_arg: ExprLike, autobreak: bool = True):
self._switch_arg = switch_arg
self._cases: dict[str, SfgSequence] = dict()
self._default: SfgSequence | None = None
self._autobreak = autobreak
def case(self, label: str):
if label in self._cases:
raise SfgException(f"Duplicate case: {label}")
def sequencer(*args: SequencerArg):
if self._autobreak:
args += ("break;",)
tree = make_sequence(*args)
self._cases[label] = tree
return self
......
......@@ -29,7 +29,8 @@ def invoke_clang_format(code: str, options: ClangFormatOptions) -> str:
binary = options.get_option("binary")
force = options.get_option("force")
args = [binary, f"--style={options.code_style}"]
style = options.get_option("code_style")
args = [binary, f"--style={style}"]
if not shutil.which(binary):
if force:
......
......@@ -2,7 +2,6 @@ from typing import Sequence
from os import path, makedirs
from ..context import SfgContext
from .prepare import prepare_context
from .printers import SfgHeaderPrinter, SfgImplPrinter
from .clang_format import invoke_clang_format
from ..config import ClangFormatOptions
......@@ -40,8 +39,6 @@ class HeaderImplPairEmitter(AbstractEmitter):
def write_files(self, ctx: SfgContext):
"""Write the code represented by the given [SfgContext][pystencilssfg.SfgContext] to the files
specified by the output specification."""
ctx = prepare_context(ctx)
header_printer = SfgHeaderPrinter(ctx, self._ospec, self._inline_impl)
impl_printer = SfgImplPrinter(ctx, self._ospec, self._inline_impl)
......
......@@ -2,7 +2,6 @@ from typing import Sequence
from os import path, makedirs
from ..context import SfgContext
from .prepare import prepare_context
from .printers import SfgHeaderPrinter
from ..config import ClangFormatOptions
from .clang_format import invoke_clang_format
......@@ -28,8 +27,6 @@ class HeaderOnlyEmitter(AbstractEmitter):
return (path.join(self._output_directory, self._header_filename),)
def write_files(self, ctx: SfgContext):
ctx = prepare_context(ctx)
header_printer = SfgHeaderPrinter(ctx, self._ospec)
header = header_printer.get_code()
if self._clang_format is not None:
......
from __future__ import annotations
from typing import TYPE_CHECKING
from functools import reduce
from ..exceptions import SfgException
from ..ir import SfgCallTreeNode
from ..ir.source_components import (
SfgFunction,
SfgClass,
SfgConstructor,
SfgMemberVariable,
SfgInClassDefinition,
)
from ..context import SfgContext
if TYPE_CHECKING:
from ..ir.source_components import SfgHeaderInclude
class CollectIncludes:
def __call__(self, obj: object) -> set[SfgHeaderInclude]:
return self.visit(obj)
def visit(self, obj: object) -> set[SfgHeaderInclude]:
match obj:
case SfgContext():
includes = set()
for func in obj.functions():
includes |= self.visit(func)
for cls in obj.classes():
includes |= self.visit(cls)
return includes
case SfgCallTreeNode():
return reduce(
lambda accu, child: accu | self.visit(child),
obj.children,
obj.required_includes,
)
case SfgFunction(_, tree, _):
return self.visit(tree)
case SfgClass():
return reduce(
lambda accu, member: accu | (self.visit(member)),
obj.members(),
set(),
)
case SfgConstructor():
return reduce(
lambda accu, obj: accu | obj.required_includes,
obj.parameters,
set(),
)
case SfgMemberVariable():
return obj.required_includes
case SfgInClassDefinition():
return set()
case _:
raise SfgException(
f"Can't collect includes from object of type {type(obj)}"
)
def prepare_context(ctx: SfgContext):
"""Prepares a populated context for printing. Make sure to run this function on the
[SfgContext][pystencilssfg.SfgContext] before passing it to a printer.
Steps:
- Collection of includes: All defined functions and classes are traversed to collect all required
header includes
"""
# Collect all includes
required_includes = CollectIncludes().visit(ctx)
for incl in required_includes:
ctx.add_include(incl)
return ctx
......@@ -115,7 +115,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
def function(self, func: SfgFunction):
params = sorted(list(func.parameters), key=lambda p: p.name)
param_list = ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params)
return f"{func.return_type} {func.name} ( {param_list} );"
return f"{func.return_type.c_string()} {func.name} ( {param_list} );"
@visit.case(SfgClass)
def sfg_class(self, cls: SfgClass):
......@@ -241,7 +241,7 @@ class SfgImplPrinter(SfgGeneralPrinter):
def function(self, func: SfgFunction) -> str:
inline_prefix = "inline " if self._inline_impl else ""
code = (
f"{inline_prefix} {func.return_type} {func.name} ({self.param_list(func)})"
f"{inline_prefix} {func.return_type.c_string()} {func.name} ({self.param_list(func)})"
)
code += (
"{\n" + self._ctx.codestyle.indent(func.tree.get_code(self._ctx)) + "}\n"
......
......@@ -3,8 +3,10 @@ from typing import Sequence
from enum import Enum
import re
from pystencils.types import PsType, PsCustomType
from pystencils.enums import Target
from pystencils.types import UserTypeSpec, PsType, PsCustomType, create_type
from pystencils import Target
from pystencilssfg.composer.basic_composer import SequencerArg
from ..exceptions import SfgException
from ..context import SfgContext
......@@ -13,15 +15,20 @@ from ..composer import (
SfgClassComposer,
SfgComposer,
SfgComposerMixIn,
make_sequence,
)
from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude
from ..ir.source_components import SfgKernelHandle
from ..ir import (
SfgCallTreeNode,
SfgCallTreeLeaf,
SfgKernelCallNode,
)
from ..lang import SfgVar, AugExpr
from ..lang import SfgVar, AugExpr, cpptype, Ref, VarLike, _VarLike, asvar
from ..lang.cpp.sycl_accessor import SyclAccessor
accessor = SyclAccessor
class SyclComposerMixIn(SfgComposerMixIn):
......@@ -35,9 +42,8 @@ class SyclComposerMixIn(SfgComposerMixIn):
"""Obtain a `SyclHandler`, which represents a ``sycl::handler`` object."""
return SyclGroup(dims, self._ctx).var(name)
def sycl_range(self, dims: int, name: str, ref: bool = False) -> SfgVar:
ref_str = " &" if ref else ""
return SfgVar(name, PsCustomType(f"sycl::range< {dims} >{ref_str}"))
def sycl_range(self, dims: int, name: str, ref: bool = False) -> SyclRange:
return SyclRange(dims, ref=ref).var(name)
class SyclComposer(SfgBasicComposer, SfgClassComposer, SyclComposerMixIn):
......@@ -47,29 +53,54 @@ class SyclComposer(SfgBasicComposer, SfgClassComposer, SyclComposerMixIn):
super().__init__(sfg)
class SyclRange(AugExpr):
_template = cpptype("sycl::range< {dims} >", "<sycl/sycl.hpp>")
def __init__(self, dims: int, const: bool = False, ref: bool = False):
dtype = self._template(dims=dims, const=const)
if ref:
dtype = Ref(dtype)
super().__init__(dtype)
class SyclHandler(AugExpr):
"""Represents a SYCL command group handler (``sycl::handler``)."""
_type = cpptype("sycl::handler", "<sycl/sycl.hpp>")
def __init__(self, ctx: SfgContext):
dtype = PsCustomType("sycl::handler &")
dtype = Ref(self._type())
super().__init__(dtype)
self._ctx = ctx
def parallel_for(self, range: SfgVar | Sequence[int], kernel: SfgKernelHandle):
def parallel_for(
self,
range: VarLike | Sequence[int],
):
"""Generate a ``parallel_for`` kernel invocation using this command group handler.
The syntax of this uses a chain of two calls to mimic C++ syntax:
.. code-block:: Python
sfg.parallel_for(range)(
# Body
)
The body is constructed via sequencing (see `make_sequence`).
Args:
range: Object, or tuple of integers, indicating the kernel's iteration range
kernel: Handle to the pystencils-kernel to be executed
"""
self._ctx.add_include(SfgHeaderInclude("sycl/sycl.hpp", system_header=True))
if isinstance(range, _VarLike):
range = asvar(range)
kfunc = kernel.get_kernel_function()
if kfunc.target != Target.SYCL:
raise SfgException(
f"Kernel given to `parallel_for` is no SYCL kernel: {kernel.kernel_name}"
)
def check_kernel(kernel: SfgKernelHandle):
kfunc = kernel.get_kernel_function()
if kfunc.target != Target.SYCL:
raise SfgException(
f"Kernel given to `parallel_for` is no SYCL kernel: {kernel.kernel_name}"
)
id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>")
......@@ -79,26 +110,43 @@ class SyclHandler(AugExpr):
and id_regex.search(param.dtype.c_string()) is not None
)
id_param = list(filter(filter_id, kernel.scalar_parameters))[0]
tree = SfgKernelCallNode(kernel)
def sequencer(*args: SequencerArg):
id_param = []
for arg in args:
if isinstance(arg, SfgKernelCallNode):
check_kernel(arg._kernel_handle)
id_param.append(
list(filter(filter_id, arg._kernel_handle.scalar_parameters))[0]
)
if not all(item == id_param[0] for item in id_param):
raise ValueError(
"id_param should be the same for all kernels in parallel_for"
)
tree = make_sequence(*args)
kernel_lambda = SfgLambda(("=",), (id_param[0],), tree, None)
return SyclKernelInvoke(
self, SyclInvokeType.ParallelFor, range, kernel_lambda
)
kernel_lambda = SfgLambda(("=",), (id_param,), tree, None)
return SyclKernelInvoke(self, SyclInvokeType.ParallelFor, range, kernel_lambda)
return sequencer
class SyclGroup(AugExpr):
"""Represents a SYCL group (``sycl::group``)."""
_template = cpptype("sycl::group< {dims} >", "<sycl/sycl.hpp>")
def __init__(self, dimensions: int, ctx: SfgContext):
dtype = PsCustomType(f"sycl::group< {dimensions} > &")
dtype = Ref(self._template(dims=dimensions))
super().__init__(dtype)
self._dimensions = dimensions
self._ctx = ctx
def parallel_for_work_item(
self, range: SfgVar | Sequence[int], kernel: SfgKernelHandle
self, range: VarLike | Sequence[int], kernel: SfgKernelHandle
):
"""Generate a ``parallel_for_work_item` kernel invocation on this group.`
......@@ -106,8 +154,8 @@ class SyclGroup(AugExpr):
range: Object, or tuple of integers, indicating the kernel's iteration range
kernel: Handle to the pystencils-kernel to be executed
"""
self._ctx.add_include(SfgHeaderInclude("sycl/sycl.hpp", system_header=True))
if isinstance(range, _VarLike):
range = asvar(range)
kfunc = kernel.get_kernel_function()
if kfunc.target != Target.SYCL:
......@@ -128,18 +176,15 @@ class SyclGroup(AugExpr):
comp = SfgComposer(self._ctx)
tree = comp.seq(
comp.map_param(
id_param,
h_item,
f"{id_param.dtype.c_string()} {id_param.name} = {h_item}.get_local_id();",
),
comp.set_param(id_param, AugExpr.format("{}.get_local_id()", h_item)),
SfgKernelCallNode(kernel),
)
kernel_lambda = SfgLambda(("=",), (h_item,), tree, None)
return SyclKernelInvoke(
invoke = SyclKernelInvoke(
self, SyclInvokeType.ParallelForWorkItem, range, kernel_lambda
)
return invoke
class SfgLambda:
......@@ -150,12 +195,14 @@ class SfgLambda:
captures: Sequence[str],
params: Sequence[SfgVar],
tree: SfgCallTreeNode,
return_type: PsType | None = None,
return_type: UserTypeSpec | None = None,
) -> None:
self._captures = tuple(captures)
self._params = tuple(params)
self._tree = tree
self._return_type = return_type
self._return_type: PsType | None = (
create_type(return_type) if return_type is not None else None
)
from ..ir.postprocessing import CallTreePostProcessing
......@@ -234,7 +281,7 @@ class SyclKernelInvoke(SfgCallTreeLeaf):
)
self._lambda = lamb
self._required_params = invoker.depends | lamb.required_parameters
self._required_params = set(invoker.depends | lamb.required_parameters)
if isinstance(range, SfgVar):
self._required_params.add(range)
......
......@@ -74,9 +74,10 @@ class SourceFileGenerator:
project_info=cli_params.get_project_info(),
)
from pystencilssfg.ir import SfgHeaderInclude
from .lang import HeaderFile
from .ir import SfgHeaderInclude
self._context.add_include(SfgHeaderInclude("cstdint", system_header=True))
self._context.add_include(SfgHeaderInclude(HeaderFile("cstdint", system_header=True)))
self._context.add_definition("#define RESTRICT __restrict__")
output_mode = config.get_option("output_mode")
......@@ -114,4 +115,9 @@ class SourceFileGenerator:
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
# Collect header files for inclusion
from .ir import SfgHeaderInclude, collect_includes
for header in collect_includes(self._context):
self._context.add_include(SfgHeaderInclude(header))
self._emitter.write_files(self._context)
......@@ -31,6 +31,7 @@ from .source_components import (
SfgConstructor,
SfgClass,
)
from .analysis import collect_includes
__all__ = [
"SfgCallTreeNode",
......@@ -61,4 +62,5 @@ __all__ = [
"SfgMethod",
"SfgConstructor",
"SfgClass",
"collect_includes"
]
from __future__ import annotations
from typing import Any
from functools import reduce
from ..exceptions import SfgException
from ..lang import HeaderFile, includes
def collect_includes(obj: Any) -> set[HeaderFile]:
from ..context import SfgContext
from .call_tree import SfgCallTreeNode
from .source_components import (
SfgFunction,
SfgClass,
SfgConstructor,
SfgMemberVariable,
SfgInClassDefinition,
)
match obj:
case SfgContext():
headers = set()
for func in obj.functions():
headers |= collect_includes(func)
for cls in obj.classes():
headers |= collect_includes(cls)
return headers
case SfgCallTreeNode():
return reduce(
lambda accu, child: accu | collect_includes(child),
obj.children,
obj.required_includes,
)
case SfgFunction(_, tree, parameters):
param_headers: set[HeaderFile] = reduce(
set.union, (includes(p) for p in parameters), set()
)
return param_headers | collect_includes(tree)
case SfgClass():
return reduce(
lambda accu, member: accu | (collect_includes(member)),
obj.members(),
set(),
)
case SfgConstructor(parameters):
param_headers = reduce(
set.union, (includes(p) for p in parameters), set()
)
return param_headers
case SfgMemberVariable():
return includes(obj)
case SfgInClassDefinition():
return set()
case _:
raise SfgException(
f"Can't collect includes from object of type {type(obj)}"
)