diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 9f0673de3c109bbdbbbfa1eaf1a3ebec51af87cf..7fda6c9ae2b115bbafdc40ee670625ceec61a578 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -6,7 +6,7 @@ import sympy as sp from functools import reduce from pystencils import Field, TypedSymbol -from pystencils.backend import KernelParameter, KernelFunction +from pystencils.backend import KernelFunction from pystencils.types import ( create_type, UserTypeSpec, @@ -31,6 +31,7 @@ from ..ir import ( ) from ..ir.postprocessing import ( SfgDeferredParamMapping, + SfgDeferredParamSetter, SfgDeferredFieldMapping, SfgDeferredVectorMapping, ) @@ -298,6 +299,10 @@ class SfgBasicComposer(SfgIComposer): return SfgInplaceInitBuilder(_asvar(lhs)) def expr(self, fmt: str, *deps, **kwdeps): + """Create an expression while keeping track of variables it depends on. + + + """ return AugExpr.format(fmt, *deps, **kwdeps) @property @@ -331,26 +336,37 @@ class SfgBasicComposer(SfgIComposer): """ return SfgDeferredFieldMapping(field, index_provider) + def set_param(self, param: VarLike | sp.Symbol, expr: ExprLike): + depends = _depends(expr) + var = _asvar(param) if isinstance(param, _VarLike) else param + return SfgDeferredParamSetter(var, depends, str(expr)) + def map_param( self, - lhs: SfgVar, - rhs: SfgVar | Sequence[SfgVar], + param: VarLike | sp.Symbol, + depends: VarLike | Sequence[VarLike], mapping: str, ): """Arbitrary parameter mapping: Add a single line of code to define a left-hand side object from one or multiple right-hand side dependencies.""" - if isinstance(rhs, (KernelParameter, SfgVar)): - rhs = [rhs] - return SfgDeferredParamMapping(lhs, set(rhs), mapping) + if isinstance(depends, _VarLike): + depends = [depends] + lhs_var = _asvar(param) if isinstance(param, _VarLike) else param + return SfgDeferredParamMapping( + lhs_var, set(_asvar(v) for v in depends), mapping + ) - def map_vector(self, lhs_components: Sequence[SfgVar | sp.Symbol], rhs: SrcVector): + def map_vector(self, lhs_components: Sequence[VarLike | sp.Symbol], rhs: SrcVector): """Extracts scalar numerical values from a vector data type. Args: lhs_components: Vector components as a list of symbols. rhs: A `SrcVector` object representing a vector data structure. """ - return SfgDeferredVectorMapping(lhs_components, rhs) + components = [ + (_asvar(c) if isinstance(c, _VarLike) else c) for c in lhs_components + ] + return SfgDeferredVectorMapping(components, rhs) def make_statements(arg: ExprLike) -> SfgStatements: diff --git a/src/pystencilssfg/ir/call_tree.py b/src/pystencilssfg/ir/call_tree.py index 34d50182b7bca8348a161553e4f7edaa4ae3c0d0..4db0daf2a37c0b7e11c24e3d79735378aa78096b 100644 --- a/src/pystencilssfg/ir/call_tree.py +++ b/src/pystencilssfg/ir/call_tree.py @@ -123,6 +123,10 @@ class SfgStatements(SfgCallTreeLeaf): def required_includes(self) -> set[SfgHeaderInclude]: return self._required_includes + @property + def code_string(self) -> str: + return self._code_string + def get_code(self, ctx: SfgContext) -> str: return self._code_string diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index c30910d59f6c6626dd6fc147f93491e9b175dace..e796a69cd856abfd2cadabd6b28ea5292d310d7b 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Sequence, Iterable import warnings from functools import reduce from dataclasses import dataclass @@ -9,6 +9,7 @@ from abc import ABC, abstractmethod import sympy as sp from pystencils import Field, TypedSymbol +from pystencils.types import deconstify from pystencils.backend.kernelfunction import ( FieldPointerParam, FieldShapeParam, @@ -61,7 +62,7 @@ class FlattenSequences: class PostProcessingContext: def __init__(self, enclosing_class: SfgClass | None = None) -> None: self.enclosing_class: SfgClass | None = enclosing_class - self.live_objects: set[SfgVar] = set() + self._live_variables: dict[str, SfgVar] = dict() def is_method(self) -> bool: return self.enclosing_class is not None @@ -72,42 +73,73 @@ class PostProcessingContext: return self.enclosing_class + @property + def live_variables(self) -> set[SfgVar]: + return set(self._live_variables.values()) -@dataclass(frozen=True) -class PostProcessingResult: - function_params: set[SfgVar] + def get_live_variable(self, name: str) -> SfgVar | None: + return self._live_variables.get(name) + def _define(self, vars: Iterable[SfgVar], expr: str): + for var in vars: + if var.name in self._live_variables: + live_var = self._live_variables[var.name] -class CallTreePostProcessing: - def __init__(self, enclosing_class: SfgClass | None = None): - self._enclosing_class = enclosing_class - self._flattener = FlattenSequences() + live_var_dtype = live_var.dtype + def_dtype = var.dtype - def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult: - params = self.get_live_objects(ast) - params_by_name: dict[str, SfgVar] = dict() + # A const definition conflicts with a non-const live variable + # A non-const definition is always OK, but then the types must be the same + if (def_dtype.const and not live_var_dtype.const) or ( + deconstify(def_dtype) != deconstify(live_var_dtype) + ): + warnings.warn( + f"Type conflict at variable definition: Expected type {live_var_dtype}, but got {def_dtype}.\n" + f" * At definition {expr}", + UserWarning, + ) + + del self._live_variables[var.name] - for param in params: - if param.name in params_by_name: - other = params_by_name[param.name] + def _use(self, vars: Iterable[SfgVar]): + for var in vars: + if var.name in self._live_variables: + live_var = self._live_variables[var.name] - if param.dtype == other.dtype: + if var.dtype == live_var.dtype: + # This can only happen if the variables are SymbolLike, + # i.e. wrap a field-associated kernel parameter + # TODO: Once symbol properties are a thing, check and combine them here warnings.warn( - "Encountered two non-identical parameters with same name and data type:\n" - f" {repr(param)}\n" + "Encountered two non-identical variables with same name and data type:\n" + f" {var.name_and_type()}\n" "and\n" - f" {repr(other)}\n" + f" {live_var.name_and_type()}\n" ) else: raise SfgException( - "Encountered two parameters with same name but different data types:\n" - f" {repr(param)}\n" + "Encountered two variables with same name but different data types:\n" + f" {var.name_and_type()}\n" "and\n" - f" {repr(other)}" + f" {live_var.name_and_type()}" ) - params_by_name[param.name] = param + else: + self._live_variables[var.name] = var - return PostProcessingResult(set(params_by_name.values())) + +@dataclass(frozen=True) +class PostProcessingResult: + function_params: set[SfgVar] + + +class CallTreePostProcessing: + def __init__(self, enclosing_class: SfgClass | None = None): + self._enclosing_class = enclosing_class + self._flattener = FlattenSequences() + + def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult: + live_vars = self.get_live_variables(ast) + return PostProcessingResult(live_vars) def handle_sequence(self, seq: SfgSequence, ppc: PostProcessingContext): def iter_nested_sequences(seq: SfgSequence): @@ -122,18 +154,18 @@ class CallTreePostProcessing: iter_nested_sequences(c) else: if isinstance(c, SfgStatements): - ppc.live_objects -= c.defines + ppc._define(c.defines, c.code_string) - ppc.live_objects |= self.get_live_objects(c) + ppc._use(self.get_live_variables(c)) iter_nested_sequences(seq) - def get_live_objects(self, node: SfgCallTreeNode) -> set[SfgVar]: + def get_live_variables(self, node: SfgCallTreeNode) -> set[SfgVar]: match node: case SfgSequence(): ppc = self._ppc() self.handle_sequence(node, ppc) - return ppc.live_objects + return ppc.live_variables case SfgCallTreeLeaf(): return node.depends @@ -144,7 +176,7 @@ class CallTreePostProcessing: case _: return reduce( lambda x, y: x | y, - (self.get_live_objects(c) for c in node.children), + (self.get_live_variables(c) for c in node.children), set(), ) @@ -177,14 +209,30 @@ class SfgDeferredNode(SfgCallTreeNode, ABC): class SfgDeferredParamMapping(SfgDeferredNode): - def __init__(self, lhs: SfgVar, rhs: set[SfgVar], mapping: str): + def __init__(self, lhs: SfgVar | sp.Symbol, depends: set[SfgVar], mapping: str): self._lhs = lhs - self._rhs = rhs + self._depends = depends self._mapping = mapping def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: - if self._lhs in ppc.live_objects: - return SfgStatements(self._mapping, (self._lhs,), tuple(self._rhs)) + live_var = ppc.get_live_variable(self._lhs.name) + if live_var is not None: + return SfgStatements(self._mapping, (live_var,), tuple(self._depends)) + else: + return SfgSequence([]) + + +class SfgDeferredParamSetter(SfgDeferredNode): + def __init__(self, param: SfgVar | sp.Symbol, depends: set[SfgVar], rhs_expr: str): + self._lhs = param + self._depends = depends + self._rhs_expr = rhs_expr + + def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: + live_var = ppc.get_live_variable(self._lhs.name) + if live_var is not None: + code = f"{live_var.dtype} {live_var.name} = {self._rhs_expr};" + return SfgStatements(code, (live_var,), tuple(self._depends)) else: return SfgSequence([]) @@ -209,7 +257,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): self._field.strides ) - for param in ppc.live_objects: + for param in ppc.live_variables: # idk why, but mypy does not understand these pattern matches match param: case SfgSymbolLike(FieldPointerParam(_, _, field)) if field == self._field: # type: ignore @@ -288,7 +336,7 @@ class SfgDeferredVectorMapping(SfgDeferredNode): def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: nodes = [] - for param in ppc.live_objects: + for param in ppc.live_variables: if param.name in self._scalars: idx, _ = self._scalars[param.name] expr = self._vector.extract_component(idx) diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 351c98984db665a5e00d696f6c44c1777d297906..2eab9935ae7cc5be7fa2d6a766ed22bd2ee121dd 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -243,6 +243,9 @@ class SfgVar: def required_includes(self) -> set[SfgHeaderInclude]: return self._required_includes + def name_and_type(self) -> str: + return f"{self._name}: {self._dtype}" + def __str__(self) -> str: return self._name diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..1fc6e827aa11238b64ee9a05695ca7caa9dbc534 --- /dev/null +++ b/tests/ir/test_postprocessing.py @@ -0,0 +1,77 @@ +import sympy as sp +from pystencils import fields, kernel, TypedSymbol + +from pystencilssfg import SfgContext, SfgComposer +from pystencilssfg.composer import make_sequence + +from pystencilssfg.ir import SfgStatements +from pystencilssfg.ir.postprocessing import CallTreePostProcessing + + +def test_live_vars(): + ctx = SfgContext() + sfg = SfgComposer(ctx) + + f, g = fields("f, g(2): double[2D]") + x, y = [TypedSymbol(n, "double") for n in "xy"] + z = sp.Symbol("z") + + @kernel + def update(): + f[0, 0] @= x * g.center(0) + y * g.center(1) - z + + khandle = sfg.kernels.create(update) + + a = sfg.var("a", "float") + b = sfg.var("b", "float") + + call_tree = make_sequence( + sfg.init(x)(a), sfg.init(y)(sfg.expr("{} - {}", b, x)), sfg.call(khandle) # # + ) + + pp = CallTreePostProcessing() + free_vars = pp.get_live_variables(call_tree) + + expected = {a.as_variable(), b.as_variable()} | set( + param for param in khandle.parameters if param.name not in "xy" + ) + + assert free_vars == expected + + +def test_find_sympy_symbols(): + ctx = SfgContext() + sfg = SfgComposer(ctx) + + f, g = fields("f, g(2): double[2D]") + x, y, z = sp.symbols("x, y, z") + + @kernel + def update(): + f[0, 0] @= x * g.center(0) + y * g.center(1) - z + + khandle = sfg.kernels.create(update) + + a = sfg.var("a", "double") + b = sfg.var("b", "double") + + call_tree = make_sequence( + sfg.set_param(x, b), + sfg.set_param(y, sfg.expr("{} / {}", x, a)), + sfg.call(khandle), + ) + + pp = CallTreePostProcessing() + live_vars = pp.get_live_variables(call_tree) + + expected = {a.as_variable(), b.as_variable()} | set( + param for param in khandle.parameters if param.name not in "xy" + ) + + assert live_vars == expected + + assert isinstance(call_tree.children[0], SfgStatements) + assert call_tree.children[0].code_string == "const double x = b;" + + assert isinstance(call_tree.children[1], SfgStatements) + assert call_tree.children[1].code_string == "const double y = x / a;"