diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 59938026ed670cb0e22c678220e7254312146eef..97966247d4be769dfcb800e0b25bcd631bdda69e 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -407,7 +407,10 @@ def make_sequence(*args: SequencerArg) -> SfgSequence: class SfgInplaceInitBuilder(SfgNodeBuilder): - def __init__(self, lhs: SfgVar) -> None: + def __init__(self, lhs: SfgVar | AugExpr) -> None: + if isinstance(lhs, AugExpr): + lhs = lhs.as_variable() + self._lhs: SfgVar = lhs self._depends: set[SfgVar] = set() self._rhs: str | None = None diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 06788cf5918c8dddd3aa0ca687a938d3722879ae..351c98984db665a5e00d696f6c44c1777d297906 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -197,7 +197,7 @@ class SfgKernelHandle: @property def fields(self): - return self.fields + return self._fields def get_kernel_function(self) -> KernelFunction: return self._namespace.get_kernel_function(self) diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index c456194a11fd6da541c96b76f2bf1e5b3a8bc984..579646f3ee5b96dd3a57c1ece45c6b6407c48102 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -48,16 +48,27 @@ class DependentExpression: def __add__(self, other: DependentExpression): return DependentExpression(self.expr + other.expr, self.depends | other.depends) + + +class VarExpr(DependentExpression): + def __init__(self, var: SfgVar): + self._var = var + super().__init__(var.name, (var,)) + + @property + def variable(self) -> SfgVar: + return self._var class AugExpr: def __init__(self, dtype: PsType | None = None): self._dtype = dtype self._bound: DependentExpression | None = None + self._is_variable = False def var(self, name: str): v = SfgVar(name, self.get_dtype(), self.required_includes) - expr = DependentExpression(name, (v,)) + expr = VarExpr(v) return self._bind(expr) @staticmethod @@ -98,6 +109,15 @@ class AugExpr: raise SfgException("This AugExpr has no known data type.") return self._dtype + + @property + def is_variable(self) -> bool: + return isinstance(self._bound, VarExpr) + + def as_variable(self) -> SfgVar: + if not isinstance(self._bound, VarExpr): + raise SfgException("This expression is not a variable") + return self._bound.variable @property def required_includes(self) -> set[SfgHeaderInclude]: