diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index 4ee499126fe795f95284d9c40d02513f53333054..dc80202427036b8563f133666f71c5b4e60ec4ec 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -14,7 +14,7 @@ from ..composer import ( SfgComposer, SfgComposerMixIn, ) -from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude, SfgKernelParamVar +from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude from ..ir import ( SfgCallTreeNode, SfgCallTreeLeaf, @@ -73,7 +73,7 @@ class SyclHandler(AugExpr): id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>") - def filter_id(param: SfgKernelParamVar) -> bool: + def filter_id(param: SfgVar) -> bool: return ( isinstance(param.dtype, PsCustomType) and id_regex.search(param.dtype.c_string()) is not None @@ -117,7 +117,7 @@ class SyclGroup(AugExpr): id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>") - def filter_id(param: SfgKernelParamVar) -> bool: + def filter_id(param: SfgVar) -> bool: return ( isinstance(param.dtype, PsCustomType) and id_regex.search(param.dtype.c_string()) is not None diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 4398938327242e8e1c1fb2cfc7787c72fa0d3138..cf4d103a2a93e363e73282a31eb0b6523f510c6e 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -163,7 +163,7 @@ class SfgKernelHandle: self._namespace = namespace self._parameters = [SfgKernelParamVar(p) for p in parameters] - self._scalar_params: set[SfgKernelParamVar] = set() + self._scalar_params: set[SfgVar] = set() self._fields: set[Field] = set() for param in self._parameters: @@ -193,7 +193,7 @@ class SfgKernelHandle: return self._parameters @property - def scalar_parameters(self): + def scalar_parameters(self) -> set[SfgVar]: return self._scalar_params @property diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index d67ffa0c845b16c867064365cec8b5af5ffe6a2a..b5532bf77464dbc1e32f8449158ec4ae34418f30 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -12,7 +12,7 @@ from .expressions import ( SrcVector, ) -from .types import Ref +from .types import Ref, strip_ptr_ref __all__ = [ "SfgVar", @@ -27,4 +27,5 @@ __all__ = [ "SrcField", "SrcVector", "Ref", + "strip_ptr_ref" ] diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 32b4754f89e32a6e5d63c0e27aece6eea83c0235..c8ac0f4cbc95c6af7f5b283852a344c7cc24cb45 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -174,23 +174,30 @@ class AugExpr: """Create a new `AugExpr` by combining existing expressions.""" return AugExpr().bind(fmt, *deps, **kwdeps) - def bind(self, fmt: str, *deps, **kwdeps): - dependencies: set[SfgVar] = set() - - from pystencils.sympyextensions import is_constant - - for expr in chain(deps, kwdeps.values()): - if isinstance(expr, _ExprLike): - dependencies |= depends(expr) - elif isinstance(expr, sp.Expr) and not is_constant(expr): - raise ValueError( - f"Cannot parse SymPy expression as C++ expression: {expr}\n" - " * pystencils-sfg is currently unable to parse non-constant SymPy expressions " - "since they contain symbols without type information." - ) - - code = fmt.format(*deps, **kwdeps) - self._bind(DependentExpression(code, dependencies)) + def bind(self, fmt: str | AugExpr, *deps, **kwdeps): + if isinstance(fmt, AugExpr): + if bool(deps) or bool(kwdeps): + raise ValueError("Binding to another AugExpr does not permit additional arguments") + if fmt._bound is None: + raise ValueError("Cannot rebind to unbound AugExpr.") + self._bind(fmt._bound) + else: + dependencies: set[SfgVar] = set() + + from pystencils.sympyextensions import is_constant + + for expr in chain(deps, kwdeps.values()): + if isinstance(expr, _ExprLike): + dependencies |= depends(expr) + elif isinstance(expr, sp.Expr) and not is_constant(expr): + raise ValueError( + f"Cannot parse SymPy expression as C++ expression: {expr}\n" + " * pystencils-sfg is currently unable to parse non-constant SymPy expressions " + "since they contain symbols without type information." + ) + + code = fmt.format(*deps, **kwdeps) + self._bind(DependentExpression(code, dependencies)) return self def expr(self) -> DependentExpression: diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py index 6f23160075050c6dfa33fd17636c2a3f826f263a..084f1d529a020b9796aeb82e208579f6f1aa5724 100644 --- a/src/pystencilssfg/lang/types.py +++ b/src/pystencilssfg/lang/types.py @@ -1,5 +1,5 @@ from typing import Any -from pystencils.types import PsType +from pystencils.types import PsType, PsPointerType class Ref(PsType): @@ -24,3 +24,13 @@ class Ref(PsType): def __repr__(self) -> str: return f"Ref({repr(self.base_type)})" + + +def strip_ptr_ref(dtype: PsType): + match dtype: + case Ref(): + return strip_ptr_ref(dtype.base_type) + case PsPointerType(): + return strip_ptr_ref(dtype.base_type) + case _: + return dtype