Skip to content
Snippets Groups Projects
Commit 3534ed16 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Toward cleaning up variables and expressions in the composer

parent 91889646
No related branches found
No related tags found
No related merge requests found
Pipeline #69644 canceled
from __future__ import annotations from __future__ import annotations
from typing import Sequence from typing import Sequence, TypeAlias
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from functools import reduce from functools import reduce
from pystencils import Field from pystencils import Field, TypedSymbol
from pystencils.backend import KernelParameter, KernelFunction from pystencils.backend import KernelParameter, KernelFunction
from pystencils.types import create_type, UserTypeSpec, PsCustomType, PsPointerType from pystencils.types import (
create_type,
UserTypeSpec,
PsCustomType,
PsPointerType,
PsType,
)
from ..context import SfgContext from ..context import SfgContext
from .custom import CustomGenerator from .custom import CustomGenerator
...@@ -58,8 +64,23 @@ class SfgNodeBuilder(ABC): ...@@ -58,8 +64,23 @@ class SfgNodeBuilder(ABC):
pass pass
ExprLike = str | SfgVar | AugExpr _ExprLike = (str, AugExpr, TypedSymbol)
SequencerArg = tuple | str | AugExpr | SfgCallTreeNode | SfgNodeBuilder ExprLike: TypeAlias = str | AugExpr | TypedSymbol
"""Things that may act as a C++ expression.
Expressions need not necesserily have a known data type.
"""
_VarLike = (TypedSymbol, AugExpr)
VarLike: TypeAlias = TypedSymbol | AugExpr
"""Things that may act as a variable.
Variables must always define their name *and* data type.
"""
_SequencerArg = (tuple, ExprLike, SfgCallTreeNode, SfgNodeBuilder)
SequencerArg: TypeAlias = tuple | ExprLike | SfgCallTreeNode | SfgNodeBuilder
"""Valid arguments to `make_sequence` and any sequencer that uses it."""
class SfgBasicComposer(SfgIComposer): class SfgBasicComposer(SfgIComposer):
...@@ -208,7 +229,11 @@ class SfgBasicComposer(SfgIComposer): ...@@ -208,7 +229,11 @@ class SfgBasicComposer(SfgIComposer):
num_blocks_str = str(num_blocks) num_blocks_str = str(num_blocks)
tpb_str = str(threads_per_block) tpb_str = str(threads_per_block)
stream_str = str(stream) if stream is not None else None stream_str = str(stream) if stream is not None else None
depends = _depends(num_blocks) | _depends(threads_per_block) | _depends(stream)
depends = _depends(num_blocks) | _depends(threads_per_block)
if stream is not None:
depends |= _depends(stream)
return SfgCudaKernelInvocation( return SfgCudaKernelInvocation(
kernel_handle, num_blocks_str, tpb_str, stream_str, depends kernel_handle, num_blocks_str, tpb_str, stream_str, depends
) )
...@@ -217,9 +242,9 @@ class SfgBasicComposer(SfgIComposer): ...@@ -217,9 +242,9 @@ class SfgBasicComposer(SfgIComposer):
"""Syntax sequencing. For details, see `make_sequence`""" """Syntax sequencing. For details, see `make_sequence`"""
return make_sequence(*args) return make_sequence(*args)
def params(self, *args: SfgVar) -> SfgFunctionParams: def params(self, *args: AugExpr) -> SfgFunctionParams:
"""Use inside a function body to add parameters to the function.""" """Use inside a function body to add parameters to the function."""
return SfgFunctionParams(args) return SfgFunctionParams([x.as_variable() for x in args])
def require(self, *includes: str | SfgHeaderInclude) -> SfgRequireIncludes: def require(self, *includes: str | SfgHeaderInclude) -> SfgRequireIncludes:
return SfgRequireIncludes( return SfgRequireIncludes(
...@@ -232,7 +257,7 @@ class SfgBasicComposer(SfgIComposer): ...@@ -232,7 +257,7 @@ class SfgBasicComposer(SfgIComposer):
ptr: bool = False, ptr: bool = False,
ref: bool = False, ref: bool = False,
const: bool = False, const: bool = False,
): ) -> PsType:
if ptr and ref: if ptr and ref:
raise SfgException("Create either a pointer, or a ref type, not both!") raise SfgException("Create either a pointer, or a ref type, not both!")
...@@ -250,11 +275,11 @@ class SfgBasicComposer(SfgIComposer): ...@@ -250,11 +275,11 @@ class SfgBasicComposer(SfgIComposer):
else: else:
return base_type return base_type
def var(self, name: str, dtype: UserTypeSpec) -> SfgVar: def var(self, name: str, dtype: UserTypeSpec) -> AugExpr:
"""Create a variable with given name and data type.""" """Create a variable with given name and data type."""
return SfgVar(name, create_type(dtype)) return AugExpr(create_type(dtype)).var(name)
def init(self, lhs: SfgVar) -> SfgInplaceInitBuilder: def init(self, lhs: VarLike) -> SfgInplaceInitBuilder:
"""Create a C++ in-place initialization. """Create a C++ in-place initialization.
Usage: Usage:
...@@ -270,7 +295,7 @@ class SfgBasicComposer(SfgIComposer): ...@@ -270,7 +295,7 @@ class SfgBasicComposer(SfgIComposer):
SomeClass obj { arg1, arg2, arg3 }; SomeClass obj { arg1, arg2, arg3 };
""" """
return SfgInplaceInitBuilder(lhs) return SfgInplaceInitBuilder(_asvar(lhs))
def expr(self, fmt: str, *deps, **kwdeps): def expr(self, fmt: str, *deps, **kwdeps):
return AugExpr.format(fmt, *deps, **kwdeps) return AugExpr.format(fmt, *deps, **kwdeps)
...@@ -329,15 +354,8 @@ class SfgBasicComposer(SfgIComposer): ...@@ -329,15 +354,8 @@ class SfgBasicComposer(SfgIComposer):
def make_statements(arg: ExprLike) -> SfgStatements: def make_statements(arg: ExprLike) -> SfgStatements:
match arg: depends = _depends(arg)
case str(): return SfgStatements(str(arg), (), depends)
return SfgStatements(arg, (), ())
case SfgVar(name, _):
return SfgStatements(name, (), (arg,))
case AugExpr():
return SfgStatements(str(arg), (), arg.depends)
case _:
assert False
def make_sequence(*args: SequencerArg) -> SfgSequence: def make_sequence(*args: SequencerArg) -> SfgSequence:
...@@ -392,10 +410,8 @@ def make_sequence(*args: SequencerArg) -> SfgSequence: ...@@ -392,10 +410,8 @@ def make_sequence(*args: SequencerArg) -> SfgSequence:
children.append(arg.resolve()) children.append(arg.resolve())
elif isinstance(arg, SfgCallTreeNode): elif isinstance(arg, SfgCallTreeNode):
children.append(arg) children.append(arg)
elif isinstance(arg, AugExpr): elif isinstance(arg, _ExprLike):
children.append(SfgStatements(str(arg), (), arg.depends)) children.append(make_statements(arg))
elif isinstance(arg, str):
children.append(SfgStatements(arg, (), ()))
elif isinstance(arg, tuple): elif isinstance(arg, tuple):
# Tuples are treated as blocks # Tuples are treated as blocks
subseq = make_sequence(*arg) subseq = make_sequence(*arg)
...@@ -407,25 +423,20 @@ def make_sequence(*args: SequencerArg) -> SfgSequence: ...@@ -407,25 +423,20 @@ def make_sequence(*args: SequencerArg) -> SfgSequence:
class SfgInplaceInitBuilder(SfgNodeBuilder): class SfgInplaceInitBuilder(SfgNodeBuilder):
def __init__(self, lhs: SfgVar | AugExpr) -> None: def __init__(self, lhs: SfgVar) -> None:
if isinstance(lhs, AugExpr):
lhs = lhs.as_variable()
self._lhs: SfgVar = lhs self._lhs: SfgVar = lhs
self._depends: set[SfgVar] = set() self._depends: set[SfgVar] = set()
self._rhs: str | None = None self._rhs: str | None = None
def __call__( def __call__(
self, self,
*rhs: str | AugExpr, *rhs: ExprLike,
) -> SfgInplaceInitBuilder: ) -> SfgInplaceInitBuilder:
if self._rhs is not None: if self._rhs is not None:
raise SfgException("Assignment builder used multiple times.") raise SfgException("Assignment builder used multiple times.")
self._rhs = ", ".join(str(expr) for expr in rhs) self._rhs = ", ".join(str(expr) for expr in rhs)
self._depends = reduce( self._depends = reduce(set.union, (_depends(obj) for obj in rhs), set())
set.union, (obj.depends for obj in rhs if isinstance(obj, AugExpr)), set()
)
return self return self
def resolve(self) -> SfgCallTreeNode: def resolve(self) -> SfgCallTreeNode:
...@@ -538,12 +549,30 @@ def struct_from_numpy_dtype( ...@@ -538,12 +549,30 @@ def struct_from_numpy_dtype(
return cls return cls
def _depends(expr: ExprLike | Sequence[ExprLike] | None) -> set[SfgVar]: def _asvar(var: VarLike) -> SfgVar:
match var:
case AugExpr():
return var.as_variable()
case TypedSymbol():
from pystencils import DynamicType
if isinstance(var.dtype, DynamicType):
raise SfgException(
f"Unable to cast dynamically typed symbol {var} to a variable.\n"
f"{var} has dynamic type {var.dtype}, which cannot be resolved to a type outside of a kernel."
)
return SfgVar(var.name, var.dtype)
case _:
raise ValueError(f"Invalid variable: {var}")
def _depends(expr: ExprLike) -> set[SfgVar]:
match expr: match expr:
case None | str(): case None | str():
return set() return set()
case SfgVar(): case TypedSymbol():
return {expr} return {_asvar(expr)}
case AugExpr(): case AugExpr():
return expr.depends return expr.depends
case _: case _:
......
from __future__ import annotations from __future__ import annotations
from typing import Sequence from typing import Sequence
from pystencils import TypedSymbol
from pystencils.types import PsCustomType, UserTypeSpec from pystencils.types import PsCustomType, UserTypeSpec
from ..lang import AugExpr
from ..ir import SfgCallTreeNode from ..ir import SfgCallTreeNode
from ..ir.source_components import ( from ..ir.source_components import (
SfgClass, SfgClass,
...@@ -14,12 +16,18 @@ from ..ir.source_components import ( ...@@ -14,12 +16,18 @@ from ..ir.source_components import (
SfgClassKeyword, SfgClassKeyword,
SfgVisibility, SfgVisibility,
SfgVisibilityBlock, SfgVisibilityBlock,
SfgVar,
) )
from ..exceptions import SfgException from ..exceptions import SfgException
from .mixin import SfgComposerMixIn from .mixin import SfgComposerMixIn
from .basic_composer import SfgNodeBuilder, make_sequence from .basic_composer import (
SfgNodeBuilder,
make_sequence,
_VarLike,
VarLike,
ExprLike,
_asvar,
)
class SfgClassComposer(SfgComposerMixIn): class SfgClassComposer(SfgComposerMixIn):
...@@ -46,7 +54,7 @@ class SfgClassComposer(SfgComposerMixIn): ...@@ -46,7 +54,7 @@ class SfgClassComposer(SfgComposerMixIn):
def __call__( def __call__(
self, self,
*args: ( *args: (
SfgClassMember | SfgClassComposer.ConstructorBuilder | SfgVar | str SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str
), ),
): ):
for arg in args: for arg in args:
...@@ -63,15 +71,21 @@ class SfgClassComposer(SfgComposerMixIn): ...@@ -63,15 +71,21 @@ class SfgClassComposer(SfgComposerMixIn):
Returned by `constructor`. Returned by `constructor`.
""" """
def __init__(self, *params: SfgVar): def __init__(self, *params: VarLike):
self._params = params self._params = tuple(_asvar(p) for p in params)
self._initializers: list[str] = [] self._initializers: list[str] = []
self._body: str | None = None self._body: str | None = None
def init(self, initializer: str) -> SfgClassComposer.ConstructorBuilder: def init(self, var: VarLike):
"""Add an initialization expression to the constructor's initializer list.""" """Add an initialization expression to the constructor's initializer list."""
self._initializers.append(initializer)
return self def init_sequencer(expr: ExprLike):
expr = str(expr)
initializer = f"{_asvar(var)}{{ {expr} }}"
self._initializers.append(initializer)
return self
return init_sequencer
def body(self, body: str): def body(self, body: str):
"""Define the constructor body""" """Define the constructor body"""
...@@ -120,7 +134,7 @@ class SfgClassComposer(SfgComposerMixIn): ...@@ -120,7 +134,7 @@ class SfgClassComposer(SfgComposerMixIn):
"""Create a `private` visibility block in a class or struct body""" """Create a `private` visibility block in a class or struct body"""
return SfgClassComposer.VisibilityContext(SfgVisibility.PRIVATE) return SfgClassComposer.VisibilityContext(SfgVisibility.PRIVATE)
def constructor(self, *params: SfgVar): def constructor(self, *params: VarLike):
"""In a class or struct body or visibility block, add a constructor. """In a class or struct body or visibility block, add a constructor.
Args: Args:
...@@ -171,7 +185,7 @@ class SfgClassComposer(SfgComposerMixIn): ...@@ -171,7 +185,7 @@ class SfgClassComposer(SfgComposerMixIn):
SfgClassComposer.VisibilityContext SfgClassComposer.VisibilityContext
| SfgClassMember | SfgClassMember
| SfgClassComposer.ConstructorBuilder | SfgClassComposer.ConstructorBuilder
| SfgVar | VarLike
| str | str
), ),
): ):
...@@ -186,9 +200,9 @@ class SfgClassComposer(SfgComposerMixIn): ...@@ -186,9 +200,9 @@ class SfgClassComposer(SfgComposerMixIn):
( (
SfgClassMember, SfgClassMember,
SfgClassComposer.ConstructorBuilder, SfgClassComposer.ConstructorBuilder,
SfgVar,
str, str,
), )
+ _VarLike,
): ):
if default_ended: if default_ended:
raise SfgException( raise SfgException(
...@@ -204,13 +218,17 @@ class SfgClassComposer(SfgComposerMixIn): ...@@ -204,13 +218,17 @@ class SfgClassComposer(SfgComposerMixIn):
@staticmethod @staticmethod
def _resolve_member( def _resolve_member(
arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | SfgVar | str, arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str,
): ) -> SfgClassMember:
if isinstance(arg, SfgVar): match arg:
return SfgMemberVariable(arg.name, arg.dtype) case AugExpr() | TypedSymbol():
elif isinstance(arg, str): var = _asvar(arg)
return SfgInClassDefinition(arg) return SfgMemberVariable(var.name, var.dtype)
elif isinstance(arg, SfgClassComposer.ConstructorBuilder): case str():
return arg.resolve() return SfgInClassDefinition(arg)
else: case SfgClassComposer.ConstructorBuilder():
return arg return arg.resolve()
case SfgClassMember():
return arg
case _:
raise ValueError(f"Invalid class member: {arg}")
...@@ -48,7 +48,7 @@ class DependentExpression: ...@@ -48,7 +48,7 @@ class DependentExpression:
def __add__(self, other: DependentExpression): def __add__(self, other: DependentExpression):
return DependentExpression(self.expr + other.expr, self.depends | other.depends) return DependentExpression(self.expr + other.expr, self.depends | other.depends)
class VarExpr(DependentExpression): class VarExpr(DependentExpression):
def __init__(self, var: SfgVar): def __init__(self, var: SfgVar):
...@@ -61,6 +61,8 @@ class VarExpr(DependentExpression): ...@@ -61,6 +61,8 @@ class VarExpr(DependentExpression):
class AugExpr: class AugExpr:
__match_args__ = ("expr", "dtype")
def __init__(self, dtype: PsType | None = None): def __init__(self, dtype: PsType | None = None):
self._dtype = dtype self._dtype = dtype
self._bound: DependentExpression | None = None self._bound: DependentExpression | None = None
...@@ -77,6 +79,7 @@ class AugExpr: ...@@ -77,6 +79,7 @@ class AugExpr:
@staticmethod @staticmethod
def format(fmt: str, *deps, **kwdeps) -> AugExpr: def format(fmt: str, *deps, **kwdeps) -> AugExpr:
"""Create a new `AugExpr` by combining existing expressions."""
return AugExpr().bind(fmt, *deps, **kwdeps) return AugExpr().bind(fmt, *deps, **kwdeps)
def bind(self, fmt: str, *deps, **kwdeps): def bind(self, fmt: str, *deps, **kwdeps):
...@@ -109,11 +112,11 @@ class AugExpr: ...@@ -109,11 +112,11 @@ class AugExpr:
raise SfgException("This AugExpr has no known data type.") raise SfgException("This AugExpr has no known data type.")
return self._dtype return self._dtype
@property @property
def is_variable(self) -> bool: def is_variable(self) -> bool:
return isinstance(self._bound, VarExpr) return isinstance(self._bound, VarExpr)
def as_variable(self) -> SfgVar: def as_variable(self) -> SfgVar:
if not isinstance(self._bound, VarExpr): if not isinstance(self._bound, VarExpr):
raise SfgException("This expression is not a variable") raise SfgException("This expression is not a variable")
......
#pragma once
#include <cstdint>
#define RESTRICT __restrict__
class Scale {
private:
float alpha;
public:
Scale(float alpha) : alpha{ alpha } {}
void operator() (float *const _data_f, float *const _data_g);
};
...@@ -20,4 +20,4 @@ with SourceFileGenerator() as sfg: ...@@ -20,4 +20,4 @@ with SourceFileGenerator() as sfg:
sfg.map_field(u_dst, mdspan_ref(u_dst)), sfg.map_field(u_dst, mdspan_ref(u_dst)),
sfg.map_field(f, mdspan_ref(f)), sfg.map_field(f, mdspan_ref(f)),
sfg.call(poisson_kernel) sfg.call(poisson_kernel)
) )
\ No newline at end of file
import sympy as sp
from pystencils import TypedSymbol, fields, kernel
from pystencilssfg import SourceFileGenerator, SfgConfiguration
with SourceFileGenerator() as sfg:
α = TypedSymbol("alpha", "float32")
f, g = fields("f, g: float32[10]")
@kernel
def scale():
f[0] @= α * g.center()
khandle = sfg.kernels.create(scale)
sfg.klass("Scale")(
sfg.private(α),
sfg.public(
sfg.constructor(α).init(α)(α.name),
sfg.method("operator()")(sfg.init(α)(f"this->{α}"), sfg.call(khandle)),
),
)
...@@ -15,20 +15,45 @@ EXPECTED_DIR = path.join(THIS_DIR, "expected") ...@@ -15,20 +15,45 @@ EXPECTED_DIR = path.join(THIS_DIR, "expected")
@dataclass @dataclass
class ScriptInfo: class ScriptInfo:
script_name: str script_name: str
"""Name of the generator script, without .py-extension.
Generator scripts must be located in the ``scripts`` folder.
"""
expected_outputs: tuple[str, ...] expected_outputs: tuple[str, ...]
"""List of file extensions expected to be emitted by the generator script.
Output files will all be placed in the ``out`` folder.
"""
compilable_output: str | None = None compilable_output: str | None = None
"""File extension of the output file that can be compiled.
If this is set, and the expected file exists, the ``compile_cmd`` will be
executed to check for error-free compilation of the output.
"""
compile_cmd: str = f"g++ --std=c++17 -I {THIS_DIR}/deps/mdspan/include" compile_cmd: str = f"g++ --std=c++17 -I {THIS_DIR}/deps/mdspan/include"
"""Command to be invoked to compile the generated source file."""
SCRIPTS = [ SCRIPTS = [
ScriptInfo("SimpleJacobi", ("h", "cpp"), compilable_output="cpp"), ScriptInfo("SimpleJacobi", ("h", "cpp"), compilable_output="cpp"),
ScriptInfo("SimpleClasses", ("h", "cpp")), ScriptInfo("SimpleClasses", ("h", "cpp")),
ScriptInfo("Variables", ("h", "cpp"), compilable_output="cpp"),
] ]
@pytest.mark.parametrize("script_info", SCRIPTS) @pytest.mark.parametrize("script_info", SCRIPTS)
def test_generator_script(script_info: ScriptInfo): def test_generator_script(script_info: ScriptInfo):
"""Test a generator script defined by ``script_info``.
The generator script will be run, with its output placed in the ``out`` folder.
If it is successful, its output files will be compared against
any files of the same name from the ``expected`` folder.
Finally, if any compilable files are specified, the test will attempt to compile them.
"""
script_name = script_info.script_name script_name = script_info.script_name
script_file = path.join(SCRIPTS_DIR, script_name + ".py") script_file = path.join(SCRIPTS_DIR, script_name + ".py")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment