From 82997bedc617b08f12fc7c46a4bf0981e5c2fad4 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 18 Nov 2024 16:48:39 +0100
Subject: [PATCH] pystencils API updates & features for sweep gen

 - Fix type printing after changes in pystencils
 - Introduce casting of indexing symbols in field mapping
 - Extend class composer's constructor builder to allow incremental
   building
 - Introduce a utility for stripping pointers and refs from a type

Squashed commit of the following:

commit 6d54f2ca471b07b3b4761d1af5fe03a0267cf27d
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Mon Nov 18 16:47:18 2024 +0100

    fix a doctest

commit 2e54c7a022fe6dd0d4f9984090c6c87c4aeae499
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Fri Nov 15 15:37:49 2024 +0100

    Fix data type printing

commit 1397bcb25b86815b6bce64cd997ca91747cd4588
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Thu Nov 7 14:51:10 2024 +0100

    some minor API changes

commit 2ba2fd8d2914957d183e3087d1cd6e65c3ce546a
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Wed Nov 6 15:29:36 2024 +0100

    Add `parameters` property to SfgClassComposer

commit 1a30d20218e40ef9d88d7ae0dd4afac80cb1c96e
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Tue Oct 29 17:04:19 2024 +0100

    Extend ConstructorBuilder to allow incremental addition of parameters. Fix test cases for PPing.

commit d0b8fff973dbe71d3c88f2e437a6a2767ae7cb50
Merge: 2977b58 d3e347f
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Tue Oct 29 09:20:23 2024 +0100

    Merge branch 'master' into lbwelding-features

commit 2977b58c3c6a71353ee51b4c834692e006ef34a6
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Tue Oct 29 09:17:30 2024 +0100

    Introduce casts to indexing symbols in field extraction
---
 src/pystencilssfg/composer/basic_composer.py | 16 +++++---
 src/pystencilssfg/composer/class_composer.py | 19 +++++++--
 src/pystencilssfg/emission/printers.py       |  8 ++--
 src/pystencilssfg/extensions/sycl.py         | 10 ++---
 src/pystencilssfg/ir/postprocessing.py       | 30 ++++++++++----
 src/pystencilssfg/ir/source_components.py    |  4 +-
 src/pystencilssfg/lang/__init__.py           |  3 +-
 src/pystencilssfg/lang/expressions.py        | 43 ++++++++++++--------
 src/pystencilssfg/lang/types.py              | 12 +++++-
 tests/ir/test_postprocessing.py              |  6 +--
 10 files changed, 101 insertions(+), 50 deletions(-)

diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py
index 135f6fb..15177e6 100644
--- a/src/pystencilssfg/composer/basic_composer.py
+++ b/src/pystencilssfg/composer/basic_composer.py
@@ -385,7 +385,7 @@ class SfgBasicComposer(SfgIComposer):
             args_str = ", ".join(str(arg) for arg in args)
             deps: set[SfgVar] = reduce(set.union, (depends(arg) for arg in args), set())
             return SfgStatements(
-                f"{lhs_var.dtype} {lhs_var.name} {{ {args_str} }};",
+                f"{lhs_var.dtype.c_string()} {lhs_var.name} {{ {args_str} }};",
                 (lhs_var,),
                 deps,
             )
@@ -412,7 +412,7 @@ class SfgBasicComposer(SfgIComposer):
             You can look at the expression's dependencies:
 
             >>> sorted(expr.depends, key=lambda v: v.name)
-            [x: float, y: float, z: float]
+            [x: float32, y: float32, z: float32]
 
             If you use an existing expression to create a larger one, the new expression
             inherits all variables from its parts:
@@ -421,7 +421,7 @@ class SfgBasicComposer(SfgIComposer):
             >>> expr2
             x + y * z + w
             >>> sorted(expr2.depends, key=lambda v: v.name)
-            [w: float, x: float, y: float, z: float]
+            [w: float32, x: float32, y: float32, z: float32]
 
         """
         return AugExpr.format(fmt, *deps, **kwdeps)
@@ -446,7 +446,10 @@ class SfgBasicComposer(SfgIComposer):
         return SfgSwitchBuilder(switch_arg)
 
     def map_field(
-        self, field: Field, index_provider: IFieldExtraction | SrcField
+        self,
+        field: Field,
+        index_provider: IFieldExtraction | SrcField,
+        cast_indexing_symbols: bool = True,
     ) -> SfgDeferredFieldMapping:
         """Map a pystencils field to a field data structure, from which pointers, sizes
         and strides should be extracted.
@@ -454,8 +457,11 @@ class SfgBasicComposer(SfgIComposer):
         Args:
             field: The pystencils field to be mapped
             src_object: A `IFieldIndexingProvider` object representing a field data structure.
+            cast_indexing_symbols: Whether to always introduce explicit casts for indexing symbols
         """
-        return SfgDeferredFieldMapping(field, index_provider)
+        return SfgDeferredFieldMapping(
+            field, index_provider, cast_indexing_symbols=cast_indexing_symbols
+        )
 
     def set_param(self, param: VarLike | sp.Symbol, expr: ExprLike):
         deps = depends(expr)
diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py
index bd90678..1f4c486 100644
--- a/src/pystencilssfg/composer/class_composer.py
+++ b/src/pystencilssfg/composer/class_composer.py
@@ -8,6 +8,7 @@ from ..lang import (
     VarLike,
     ExprLike,
     asvar,
+    SfgVar,
 )
 
 from ..ir.source_components import (
@@ -72,16 +73,28 @@ class SfgClassComposer(SfgComposerMixIn):
         """
 
         def __init__(self, *params: VarLike):
-            self._params = tuple(asvar(p) for p in params)
+            self._params = list(asvar(p) for p in params)
             self._initializers: list[str] = []
             self._body: str | None = None
 
-        def init(self, var: VarLike):
+        def add_param(self, param: VarLike, at: int | None = None):
+            if at is None:
+                self._params.append(asvar(param))
+            else:
+                self._params.insert(at, asvar(param))
+
+        @property
+        def parameters(self) -> list[SfgVar]:
+            return self._params
+
+        def init(self, var: VarLike | str):
             """Add an initialization expression to the constructor's initializer list."""
 
+            member = var if isinstance(var, str) else asvar(var)
+
             def init_sequencer(*args: ExprLike):
                 expr = ", ".join(str(arg) for arg in args)
-                initializer = f"{asvar(var)}{{ {expr} }}"
+                initializer = f"{member}{{ {expr} }}"
                 self._initializers.append(initializer)
                 return self
 
diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py
index 9337161..c562bf7 100644
--- a/src/pystencilssfg/emission/printers.py
+++ b/src/pystencilssfg/emission/printers.py
@@ -66,7 +66,7 @@ class SfgGeneralPrinter:
 
     def param_list(self, func: SfgFunction) -> str:
         params = sorted(list(func.parameters), key=lambda p: p.name)
-        return ", ".join(f"{param.dtype} {param.name}" for param in params)
+        return ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params)
 
 
 class SfgHeaderPrinter(SfgGeneralPrinter):
@@ -113,7 +113,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
     @visit.case(SfgFunction)
     def function(self, func: SfgFunction):
         params = sorted(list(func.parameters), key=lambda p: p.name)
-        param_list = ", ".join(f"{param.dtype} {param.name}" for param in params)
+        param_list = ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params)
         return f"{func.return_type} {func.name} ( {param_list} );"
 
     @visit.case(SfgClass)
@@ -149,7 +149,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
     @visit.case(SfgConstructor)
     def sfg_constructor(self, constr: SfgConstructor):
         code = f"{constr.owning_class.class_name} ("
-        code += ", ".join(f"{param.dtype} {param.name}" for param in constr.parameters)
+        code += ", ".join(f"{param.dtype.c_string()} {param.name}" for param in constr.parameters)
         code += ")\n"
         if constr.initializers:
             code += "  : " + ", ".join(constr.initializers) + "\n"
@@ -161,7 +161,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
 
     @visit.case(SfgMemberVariable)
     def sfg_member_var(self, var: SfgMemberVariable):
-        return f"{var.dtype} {var.name};"
+        return f"{var.dtype.c_string()} {var.name};"
 
     @visit.case(SfgMethod)
     def sfg_method(self, method: SfgMethod):
diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py
index 4ee4991..3cb0c1c 100644
--- a/src/pystencilssfg/extensions/sycl.py
+++ b/src/pystencilssfg/extensions/sycl.py
@@ -14,7 +14,7 @@ from ..composer import (
     SfgComposer,
     SfgComposerMixIn,
 )
-from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude, SfgKernelParamVar
+from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude
 from ..ir import (
     SfgCallTreeNode,
     SfgCallTreeLeaf,
@@ -73,7 +73,7 @@ class SyclHandler(AugExpr):
 
         id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>")
 
-        def filter_id(param: SfgKernelParamVar) -> bool:
+        def filter_id(param: SfgVar) -> bool:
             return (
                 isinstance(param.dtype, PsCustomType)
                 and id_regex.search(param.dtype.c_string()) is not None
@@ -117,7 +117,7 @@ class SyclGroup(AugExpr):
 
         id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>")
 
-        def filter_id(param: SfgKernelParamVar) -> bool:
+        def filter_id(param: SfgVar) -> bool:
             return (
                 isinstance(param.dtype, PsCustomType)
                 and id_regex.search(param.dtype.c_string()) is not None
@@ -131,7 +131,7 @@ class SyclGroup(AugExpr):
             comp.map_param(
                 id_param,
                 h_item,
-                f"{id_param.dtype} {id_param.name} = {h_item}.get_local_id();",
+                f"{id_param.dtype.c_string()} {id_param.name} = {h_item}.get_local_id();",
             ),
             SfgKernelCallNode(kernel),
         )
@@ -186,7 +186,7 @@ class SfgLambda:
 
     def get_code(self, ctx: SfgContext):
         captures = ", ".join(self._captures)
-        params = ", ".join(f"{p.dtype} {p.name}" for p in self._params)
+        params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params)
         body = self._tree.get_code(ctx)
         body = ctx.codestyle.indent(body)
         rtype = (
diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py
index c33ec7a..638a55f 100644
--- a/src/pystencilssfg/ir/postprocessing.py
+++ b/src/pystencilssfg/ir/postprocessing.py
@@ -9,14 +9,14 @@ from abc import ABC, abstractmethod
 import sympy as sp
 
 from pystencils import Field
-from pystencils.types import deconstify
+from pystencils.types import deconstify, PsType
 from pystencils.backend.properties import FieldBasePtr, FieldShape, FieldStride
 
 from ..exceptions import SfgException
 
 from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements
 from ..ir.source_components import SfgKernelParamVar
-from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector
+from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector, AugExpr
 
 if TYPE_CHECKING:
     from ..context import SfgContext
@@ -233,20 +233,26 @@ class SfgDeferredParamSetter(SfgDeferredNode):
     def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
         live_var = ppc.get_live_variable(self._lhs.name)
         if live_var is not None:
-            code = f"{live_var.dtype} {live_var.name} = {self._rhs_expr};"
+            code = f"{live_var.dtype.c_string()} {live_var.name} = {self._rhs_expr};"
             return SfgStatements(code, (live_var,), tuple(self._depends))
         else:
             return SfgSequence([])
 
 
 class SfgDeferredFieldMapping(SfgDeferredNode):
-    def __init__(self, psfield: Field, extraction: IFieldExtraction | SrcField):
+    def __init__(
+        self,
+        psfield: Field,
+        extraction: IFieldExtraction | SrcField,
+        cast_indexing_symbols: bool = True,
+    ):
         self._field = psfield
         self._extraction: IFieldExtraction = (
             extraction
             if isinstance(extraction, IFieldExtraction)
             else extraction.get_extraction()
         )
+        self._cast_indexing_symbols = cast_indexing_symbols
 
     def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
         #    Find field pointer
@@ -285,10 +291,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
             expr = self._extraction.ptr()
             nodes.append(
                 SfgStatements(
-                    f"{ptr.dtype} {ptr.name} {{ {expr} }};", (ptr,), expr.depends
+                    f"{ptr.dtype.c_string()} {ptr.name} {{ {expr} }};", (ptr,), expr.depends
                 )
             )
 
+        def maybe_cast(expr: AugExpr, target_type: PsType) -> AugExpr:
+            if self._cast_indexing_symbols:
+                return AugExpr(target_type).bind("{}( {} )", deconstify(target_type).c_string(), expr)
+            else:
+                return expr
+
         def get_shape(coord, symb: SfgKernelParamVar | str):
             expr = self._extraction.size(coord)
 
@@ -299,8 +311,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
 
             if isinstance(symb, SfgKernelParamVar) and symb not in done:
                 done.add(symb)
+                expr = maybe_cast(expr, symb.dtype)
                 return SfgStatements(
-                    f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends
+                    f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends
                 )
             else:
                 return SfgStatements(f"/* {expr} == {symb} */", (), ())
@@ -315,8 +328,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode):
 
             if isinstance(symb, SfgKernelParamVar) and symb not in done:
                 done.add(symb)
+                expr = maybe_cast(expr, symb.dtype)
                 return SfgStatements(
-                    f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends
+                    f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends
                 )
             else:
                 return SfgStatements(f"/* {expr} == {symb} */", (), ())
@@ -341,7 +355,7 @@ class SfgDeferredVectorMapping(SfgDeferredNode):
                 expr = self._vector.extract_component(idx)
                 nodes.append(
                     SfgStatements(
-                        f"{param.dtype} {param.name} {{ {expr} }};",
+                        f"{param.dtype.c_string()} {param.name} {{ {expr} }};",
                         (param,),
                         expr.depends,
                     )
diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py
index 4398938..cf4d103 100644
--- a/src/pystencilssfg/ir/source_components.py
+++ b/src/pystencilssfg/ir/source_components.py
@@ -163,7 +163,7 @@ class SfgKernelHandle:
         self._namespace = namespace
         self._parameters = [SfgKernelParamVar(p) for p in parameters]
 
-        self._scalar_params: set[SfgKernelParamVar] = set()
+        self._scalar_params: set[SfgVar] = set()
         self._fields: set[Field] = set()
 
         for param in self._parameters:
@@ -193,7 +193,7 @@ class SfgKernelHandle:
         return self._parameters
 
     @property
-    def scalar_parameters(self):
+    def scalar_parameters(self) -> set[SfgVar]:
         return self._scalar_params
 
     @property
diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py
index d67ffa0..b5532bf 100644
--- a/src/pystencilssfg/lang/__init__.py
+++ b/src/pystencilssfg/lang/__init__.py
@@ -12,7 +12,7 @@ from .expressions import (
     SrcVector,
 )
 
-from .types import Ref
+from .types import Ref, strip_ptr_ref
 
 __all__ = [
     "SfgVar",
@@ -27,4 +27,5 @@ __all__ = [
     "SrcField",
     "SrcVector",
     "Ref",
+    "strip_ptr_ref"
 ]
diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py
index 481922e..c8ac0f4 100644
--- a/src/pystencilssfg/lang/expressions.py
+++ b/src/pystencilssfg/lang/expressions.py
@@ -174,23 +174,30 @@ class AugExpr:
         """Create a new `AugExpr` by combining existing expressions."""
         return AugExpr().bind(fmt, *deps, **kwdeps)
 
-    def bind(self, fmt: str, *deps, **kwdeps):
-        dependencies: set[SfgVar] = set()
-
-        from pystencils.sympyextensions import is_constant
-
-        for expr in chain(deps, kwdeps.values()):
-            if isinstance(expr, _ExprLike):
-                dependencies |= depends(expr)
-            elif isinstance(expr, sp.Expr) and not is_constant(expr):
-                raise ValueError(
-                    f"Cannot parse SymPy expression as C++ expression: {expr}\n"
-                    "  * pystencils-sfg is currently unable to parse non-constant SymPy expressions "
-                    "since they contain symbols without type information."
-                )
-
-        code = fmt.format(*deps, **kwdeps)
-        self._bind(DependentExpression(code, dependencies))
+    def bind(self, fmt: str | AugExpr, *deps, **kwdeps):
+        if isinstance(fmt, AugExpr):
+            if bool(deps) or bool(kwdeps):
+                raise ValueError("Binding to another AugExpr does not permit additional arguments")
+            if fmt._bound is None:
+                raise ValueError("Cannot rebind to unbound AugExpr.")
+            self._bind(fmt._bound)
+        else:
+            dependencies: set[SfgVar] = set()
+
+            from pystencils.sympyextensions import is_constant
+
+            for expr in chain(deps, kwdeps.values()):
+                if isinstance(expr, _ExprLike):
+                    dependencies |= depends(expr)
+                elif isinstance(expr, sp.Expr) and not is_constant(expr):
+                    raise ValueError(
+                        f"Cannot parse SymPy expression as C++ expression: {expr}\n"
+                        "  * pystencils-sfg is currently unable to parse non-constant SymPy expressions "
+                        "since they contain symbols without type information."
+                    )
+
+            code = fmt.format(*deps, **kwdeps)
+            self._bind(DependentExpression(code, dependencies))
         return self
 
     def expr(self) -> DependentExpression:
@@ -251,7 +258,7 @@ class AugExpr:
         self._bound = expr
         return self
 
-    def _is_bound(self) -> bool:
+    def is_bound(self) -> bool:
         return self._bound is not None
 
 
diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py
index 6f23160..084f1d5 100644
--- a/src/pystencilssfg/lang/types.py
+++ b/src/pystencilssfg/lang/types.py
@@ -1,5 +1,5 @@
 from typing import Any
-from pystencils.types import PsType
+from pystencils.types import PsType, PsPointerType
 
 
 class Ref(PsType):
@@ -24,3 +24,13 @@ class Ref(PsType):
 
     def __repr__(self) -> str:
         return f"Ref({repr(self.base_type)})"
+
+
+def strip_ptr_ref(dtype: PsType):
+    match dtype:
+        case Ref():
+            return strip_ptr_ref(dtype.base_type)
+        case PsPointerType():
+            return strip_ptr_ref(dtype.base_type)
+        case _:
+            return dtype
diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py
index 3030e12..6a38d91 100644
--- a/tests/ir/test_postprocessing.py
+++ b/tests/ir/test_postprocessing.py
@@ -109,7 +109,7 @@ def test_field_extraction():
     khandle = sfg.kernels.create(set_constant)
 
     extraction = TestFieldExtraction("f")
-    call_tree = make_sequence(sfg.map_field(f, extraction), sfg.call(khandle))
+    call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle))
 
     pp = CallTreePostProcessing()
     free_vars = pp.get_live_variables(call_tree)
@@ -143,8 +143,8 @@ def test_duplicate_field_shapes():
     khandle = sfg.kernels.create(set_constant)
 
     call_tree = make_sequence(
-        sfg.map_field(g, TestFieldExtraction("g")),
-        sfg.map_field(f, TestFieldExtraction("f")),
+        sfg.map_field(g, TestFieldExtraction("g"), cast_indexing_symbols=False),
+        sfg.map_field(f, TestFieldExtraction("f"), cast_indexing_symbols=False),
         sfg.call(khandle),
     )
 
-- 
GitLab