From 1397bcb25b86815b6bce64cd997ca91747cd4588 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 7 Nov 2024 14:51:10 +0100
Subject: [PATCH] some minor API changes

---
 src/pystencilssfg/extensions/sycl.py      |  6 ++--
 src/pystencilssfg/ir/source_components.py |  4 +--
 src/pystencilssfg/lang/__init__.py        |  3 +-
 src/pystencilssfg/lang/expressions.py     | 41 +++++++++++++----------
 src/pystencilssfg/lang/types.py           | 12 ++++++-
 5 files changed, 42 insertions(+), 24 deletions(-)

diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py
index 4ee4991..dc80202 100644
--- a/src/pystencilssfg/extensions/sycl.py
+++ b/src/pystencilssfg/extensions/sycl.py
@@ -14,7 +14,7 @@ from ..composer import (
     SfgComposer,
     SfgComposerMixIn,
 )
-from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude, SfgKernelParamVar
+from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude
 from ..ir import (
     SfgCallTreeNode,
     SfgCallTreeLeaf,
@@ -73,7 +73,7 @@ class SyclHandler(AugExpr):
 
         id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>")
 
-        def filter_id(param: SfgKernelParamVar) -> bool:
+        def filter_id(param: SfgVar) -> bool:
             return (
                 isinstance(param.dtype, PsCustomType)
                 and id_regex.search(param.dtype.c_string()) is not None
@@ -117,7 +117,7 @@ class SyclGroup(AugExpr):
 
         id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>")
 
-        def filter_id(param: SfgKernelParamVar) -> bool:
+        def filter_id(param: SfgVar) -> bool:
             return (
                 isinstance(param.dtype, PsCustomType)
                 and id_regex.search(param.dtype.c_string()) is not None
diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py
index 4398938..cf4d103 100644
--- a/src/pystencilssfg/ir/source_components.py
+++ b/src/pystencilssfg/ir/source_components.py
@@ -163,7 +163,7 @@ class SfgKernelHandle:
         self._namespace = namespace
         self._parameters = [SfgKernelParamVar(p) for p in parameters]
 
-        self._scalar_params: set[SfgKernelParamVar] = set()
+        self._scalar_params: set[SfgVar] = set()
         self._fields: set[Field] = set()
 
         for param in self._parameters:
@@ -193,7 +193,7 @@ class SfgKernelHandle:
         return self._parameters
 
     @property
-    def scalar_parameters(self):
+    def scalar_parameters(self) -> set[SfgVar]:
         return self._scalar_params
 
     @property
diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py
index d67ffa0..b5532bf 100644
--- a/src/pystencilssfg/lang/__init__.py
+++ b/src/pystencilssfg/lang/__init__.py
@@ -12,7 +12,7 @@ from .expressions import (
     SrcVector,
 )
 
-from .types import Ref
+from .types import Ref, strip_ptr_ref
 
 __all__ = [
     "SfgVar",
@@ -27,4 +27,5 @@ __all__ = [
     "SrcField",
     "SrcVector",
     "Ref",
+    "strip_ptr_ref"
 ]
diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py
index 32b4754..c8ac0f4 100644
--- a/src/pystencilssfg/lang/expressions.py
+++ b/src/pystencilssfg/lang/expressions.py
@@ -174,23 +174,30 @@ class AugExpr:
         """Create a new `AugExpr` by combining existing expressions."""
         return AugExpr().bind(fmt, *deps, **kwdeps)
 
-    def bind(self, fmt: str, *deps, **kwdeps):
-        dependencies: set[SfgVar] = set()
-
-        from pystencils.sympyextensions import is_constant
-
-        for expr in chain(deps, kwdeps.values()):
-            if isinstance(expr, _ExprLike):
-                dependencies |= depends(expr)
-            elif isinstance(expr, sp.Expr) and not is_constant(expr):
-                raise ValueError(
-                    f"Cannot parse SymPy expression as C++ expression: {expr}\n"
-                    "  * pystencils-sfg is currently unable to parse non-constant SymPy expressions "
-                    "since they contain symbols without type information."
-                )
-
-        code = fmt.format(*deps, **kwdeps)
-        self._bind(DependentExpression(code, dependencies))
+    def bind(self, fmt: str | AugExpr, *deps, **kwdeps):
+        if isinstance(fmt, AugExpr):
+            if bool(deps) or bool(kwdeps):
+                raise ValueError("Binding to another AugExpr does not permit additional arguments")
+            if fmt._bound is None:
+                raise ValueError("Cannot rebind to unbound AugExpr.")
+            self._bind(fmt._bound)
+        else:
+            dependencies: set[SfgVar] = set()
+
+            from pystencils.sympyextensions import is_constant
+
+            for expr in chain(deps, kwdeps.values()):
+                if isinstance(expr, _ExprLike):
+                    dependencies |= depends(expr)
+                elif isinstance(expr, sp.Expr) and not is_constant(expr):
+                    raise ValueError(
+                        f"Cannot parse SymPy expression as C++ expression: {expr}\n"
+                        "  * pystencils-sfg is currently unable to parse non-constant SymPy expressions "
+                        "since they contain symbols without type information."
+                    )
+
+            code = fmt.format(*deps, **kwdeps)
+            self._bind(DependentExpression(code, dependencies))
         return self
 
     def expr(self) -> DependentExpression:
diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py
index 6f23160..084f1d5 100644
--- a/src/pystencilssfg/lang/types.py
+++ b/src/pystencilssfg/lang/types.py
@@ -1,5 +1,5 @@
 from typing import Any
-from pystencils.types import PsType
+from pystencils.types import PsType, PsPointerType
 
 
 class Ref(PsType):
@@ -24,3 +24,13 @@ class Ref(PsType):
 
     def __repr__(self) -> str:
         return f"Ref({repr(self.base_type)})"
+
+
+def strip_ptr_ref(dtype: PsType):
+    match dtype:
+        case Ref():
+            return strip_ptr_ref(dtype.base_type)
+        case PsPointerType():
+            return strip_ptr_ref(dtype.base_type)
+        case _:
+            return dtype
-- 
GitLab