diff --git a/src/pystencilssfg/__init__.py b/src/pystencilssfg/__init__.py index 7fa3378c7958c13f4fb526eb67a2bd08335400f5..b5ac38f9c9487c6b8caf77a72cda55ea7fc1e792 100644 --- a/src/pystencilssfg/__init__.py +++ b/src/pystencilssfg/__init__.py @@ -2,7 +2,8 @@ from .configuration import SfgConfiguration, SfgOutputMode, SfgCodeStyle from .generator import SourceFileGenerator from .composer import SfgComposer from .context import SfgContext -from .lang import AugExpr +from .lang import SfgVar, AugExpr +from .exceptions import SfgException __all__ = [ "SourceFileGenerator", @@ -11,7 +12,9 @@ __all__ = [ "SfgOutputMode", "SfgCodeStyle", "SfgContext", + "SfgVar", "AugExpr", + "SfgException", ] from . import _version diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index 6de8e184231a420709999ad094136154a813ed33..851a981862cb900442d92624bbaef0dbece68f89 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -117,6 +117,11 @@ class PostProcessingContext: "and\n" f" {live_var.name_and_type()}\n" ) + elif deconstify(var.dtype) == deconstify(live_var.dtype): + # Same type, just different constness + # One of them must be non-const -> keep the non-const one + if live_var.dtype.const and not var.dtype.const: + self._live_variables[var.name] = var else: raise SfgException( "Encountered two variables with same name but different data types:\n" diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 1a15fcd84f3c8200bfdd44df20c4c4b7f2ecc00f..d8bd782090fb625afc95ffc46207f2c2c7a9d8bd 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -3,8 +3,10 @@ from typing import Iterable, TypeAlias, Any, TYPE_CHECKING from itertools import chain from abc import ABC, abstractmethod +import sympy as sp + from pystencils import TypedSymbol -from pystencils.types import PsType +from pystencils.types import PsType, UserTypeSpec, create_type from ..exceptions import SfgException @@ -38,14 +40,14 @@ class SfgVar: def __init__( self, name: str, - dtype: PsType, + dtype: UserTypeSpec, required_includes: set[SfgHeaderInclude] | None = None, ): # TODO: Replace `required_includes` by using a property # Includes attached this way may currently easily be lost during postprocessing, # since they are not part of `_args` self._name = name - self._dtype = dtype + self._dtype = create_type(dtype) self._required_includes = ( required_includes if required_includes is not None else set() @@ -82,7 +84,7 @@ class SfgVar: return self._name def __repr__(self) -> str: - return f"{self._name}: {self._dtype}" + return self.name_and_type() class DependentExpression: @@ -153,8 +155,8 @@ class AugExpr: __match_args__ = ("expr", "dtype") - def __init__(self, dtype: PsType | None = None): - self._dtype = dtype + def __init__(self, dtype: UserTypeSpec | None = None): + self._dtype = create_type(dtype) if dtype is not None else None self._bound: DependentExpression | None = None self._is_variable = False @@ -173,11 +175,22 @@ class AugExpr: return AugExpr().bind(fmt, *deps, **kwdeps) def bind(self, fmt: str, *deps, **kwdeps): - depends = filter( - lambda obj: isinstance(obj, (SfgVar, AugExpr)), chain(deps, kwdeps.values()) - ) + dependencies: set[SfgVar] = set() + + from pystencils.sympyextensions import is_constant + + for expr in chain(deps, kwdeps.values()): + if isinstance(expr, _ExprLike): + dependencies |= depends(expr) + elif isinstance(expr, sp.Expr) and not is_constant(expr): + raise ValueError( + f"Cannot parse SymPy expression as C++ expression: {expr}\n" + " * pystencils-sfg is currently unable to parse non-constant SymPy expressions " + "since they contain symbols without type information." + ) + code = fmt.format(*deps, **kwdeps) - self._bind(DependentExpression(code, depends)) + self._bind(DependentExpression(code, dependencies)) return self def expr(self) -> DependentExpression: @@ -186,6 +199,12 @@ class AugExpr: return self._bound + @property + def code(self) -> str: + if self._bound is None: + raise SfgException("No syntax bound to this AugExpr.") + return str(self._bound) + @property def depends(self) -> set[SfgVar]: if self._bound is None: diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 1fc6e827aa11238b64ee9a05695ca7caa9dbc534..e144024fa831bf5c31c17b7efc93d82857654671 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -57,7 +57,7 @@ def test_find_sympy_symbols(): call_tree = make_sequence( sfg.set_param(x, b), - sfg.set_param(y, sfg.expr("{} / {}", x, a)), + sfg.set_param(y, sfg.expr("{} / {}", x.name, a)), sfg.call(khandle), ) diff --git a/tests/lang/test_expressions.py b/tests/lang/test_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..ef2f1943c07b3f0e1a23750302f043c0ebe89105 --- /dev/null +++ b/tests/lang/test_expressions.py @@ -0,0 +1,95 @@ +import pytest + +from pystencilssfg import SfgException +from pystencilssfg.lang import asvar, SfgVar, AugExpr + +import sympy as sp + +from pystencils import TypedSymbol, DynamicType + + +def test_asvar(): + # SfgVar must be returned as-is + var = SfgVar("p", "uint64") + assert var is asvar(var) + + # TypedSymbol is transformed + ts = TypedSymbol("q", "int32") + assert asvar(ts) == SfgVar("q", "int32") + + # Variable AugExprs get lowered to SfgVar + augexpr = AugExpr("uint16").var("l") + assert asvar(augexpr) == SfgVar("l", "uint16") + + # Complex AugExprs cannot be parsed + cexpr = AugExpr.format("{} + {}", SfgVar("m", "int32"), AugExpr("int32").var("n")) + with pytest.raises(SfgException): + _ = asvar(cexpr) + + # Untyped SymPy symbols won't be parsed + x = sp.Symbol("x") + with pytest.raises(ValueError): + _ = asvar(x) + + # Dynamically typed TypedSymbols cannot be parsed + y = TypedSymbol("y", DynamicType.NUMERIC_TYPE) + with pytest.raises(ValueError): + _ = asvar(y) + + +def test_augexpr_format(): + expr = AugExpr.format("std::vector< real_t > {{ 0.1, 0.2, 0.3 }}") + assert expr.code == "std::vector< real_t > { 0.1, 0.2, 0.3 }" + assert not expr.depends + + expr = AugExpr("int").var("p") + assert expr.code == "p" + assert expr.depends == {SfgVar("p", "int")} + + expr = AugExpr.format( + "{} + {} / {}", + AugExpr("int").var("p"), + AugExpr("int").var("q"), + AugExpr("uint32").var("r"), + ) + + assert str(expr) == expr.code == "p + q / r" + + assert expr.depends == { + SfgVar("p", "int"), + SfgVar("q", "int"), + SfgVar("r", "uint32"), + } + + # Must find TypedSymbols as dependencies + expr = AugExpr.format( + "{} + {} / {}", + AugExpr("int").var("p"), + TypedSymbol("x", "int32"), + TypedSymbol("y", "int32"), + ) + + assert expr.code == "p + x / y" + assert expr.depends == { + SfgVar("p", "int"), + SfgVar("x", "int32"), + SfgVar("y", "int32"), + } + + # Can parse constant SymPy expressions + expr = AugExpr.format("{}", sp.sympify(1)) + + assert expr.code == "1" + assert not expr.depends + + +def test_augexpr_illegal_format(): + x, y, z = sp.symbols("x, y, z") + + with pytest.raises(ValueError): + # Cannot parse SymPy symbols + _ = AugExpr.format("{}", x) + + with pytest.raises(ValueError): + # Cannot parse expressions containing symbols + _ = AugExpr.format("{} + {}", x + 3, y / (2 * z))