From 1397bcb25b86815b6bce64cd997ca91747cd4588 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 7 Nov 2024 14:51:10 +0100 Subject: [PATCH] some minor API changes --- src/pystencilssfg/extensions/sycl.py | 6 ++-- src/pystencilssfg/ir/source_components.py | 4 +-- src/pystencilssfg/lang/__init__.py | 3 +- src/pystencilssfg/lang/expressions.py | 41 +++++++++++++---------- src/pystencilssfg/lang/types.py | 12 ++++++- 5 files changed, 42 insertions(+), 24 deletions(-) diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index 4ee4991..dc80202 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -14,7 +14,7 @@ from ..composer import ( SfgComposer, SfgComposerMixIn, ) -from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude, SfgKernelParamVar +from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude from ..ir import ( SfgCallTreeNode, SfgCallTreeLeaf, @@ -73,7 +73,7 @@ class SyclHandler(AugExpr): id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>") - def filter_id(param: SfgKernelParamVar) -> bool: + def filter_id(param: SfgVar) -> bool: return ( isinstance(param.dtype, PsCustomType) and id_regex.search(param.dtype.c_string()) is not None @@ -117,7 +117,7 @@ class SyclGroup(AugExpr): id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>") - def filter_id(param: SfgKernelParamVar) -> bool: + def filter_id(param: SfgVar) -> bool: return ( isinstance(param.dtype, PsCustomType) and id_regex.search(param.dtype.c_string()) is not None diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 4398938..cf4d103 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -163,7 +163,7 @@ class SfgKernelHandle: self._namespace = namespace self._parameters = [SfgKernelParamVar(p) for p in parameters] - self._scalar_params: set[SfgKernelParamVar] = set() + self._scalar_params: set[SfgVar] = set() self._fields: set[Field] = set() for param in self._parameters: @@ -193,7 +193,7 @@ class SfgKernelHandle: return self._parameters @property - def scalar_parameters(self): + def scalar_parameters(self) -> set[SfgVar]: return self._scalar_params @property diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index d67ffa0..b5532bf 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -12,7 +12,7 @@ from .expressions import ( SrcVector, ) -from .types import Ref +from .types import Ref, strip_ptr_ref __all__ = [ "SfgVar", @@ -27,4 +27,5 @@ __all__ = [ "SrcField", "SrcVector", "Ref", + "strip_ptr_ref" ] diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 32b4754..c8ac0f4 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -174,23 +174,30 @@ class AugExpr: """Create a new `AugExpr` by combining existing expressions.""" return AugExpr().bind(fmt, *deps, **kwdeps) - def bind(self, fmt: str, *deps, **kwdeps): - dependencies: set[SfgVar] = set() - - from pystencils.sympyextensions import is_constant - - for expr in chain(deps, kwdeps.values()): - if isinstance(expr, _ExprLike): - dependencies |= depends(expr) - elif isinstance(expr, sp.Expr) and not is_constant(expr): - raise ValueError( - f"Cannot parse SymPy expression as C++ expression: {expr}\n" - " * pystencils-sfg is currently unable to parse non-constant SymPy expressions " - "since they contain symbols without type information." - ) - - code = fmt.format(*deps, **kwdeps) - self._bind(DependentExpression(code, dependencies)) + def bind(self, fmt: str | AugExpr, *deps, **kwdeps): + if isinstance(fmt, AugExpr): + if bool(deps) or bool(kwdeps): + raise ValueError("Binding to another AugExpr does not permit additional arguments") + if fmt._bound is None: + raise ValueError("Cannot rebind to unbound AugExpr.") + self._bind(fmt._bound) + else: + dependencies: set[SfgVar] = set() + + from pystencils.sympyextensions import is_constant + + for expr in chain(deps, kwdeps.values()): + if isinstance(expr, _ExprLike): + dependencies |= depends(expr) + elif isinstance(expr, sp.Expr) and not is_constant(expr): + raise ValueError( + f"Cannot parse SymPy expression as C++ expression: {expr}\n" + " * pystencils-sfg is currently unable to parse non-constant SymPy expressions " + "since they contain symbols without type information." + ) + + code = fmt.format(*deps, **kwdeps) + self._bind(DependentExpression(code, dependencies)) return self def expr(self) -> DependentExpression: diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py index 6f23160..084f1d5 100644 --- a/src/pystencilssfg/lang/types.py +++ b/src/pystencilssfg/lang/types.py @@ -1,5 +1,5 @@ from typing import Any -from pystencils.types import PsType +from pystencils.types import PsType, PsPointerType class Ref(PsType): @@ -24,3 +24,13 @@ class Ref(PsType): def __repr__(self) -> str: return f"Ref({repr(self.base_type)})" + + +def strip_ptr_ref(dtype: PsType): + match dtype: + case Ref(): + return strip_ptr_ref(dtype.base_type) + case PsPointerType(): + return strip_ptr_ref(dtype.base_type) + case _: + return dtype -- GitLab