From 1a30d20218e40ef9d88d7ae0dd4afac80cb1c96e Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 29 Oct 2024 17:04:19 +0100 Subject: [PATCH] Extend ConstructorBuilder to allow incremental addition of parameters. Fix test cases for PPing. --- src/pystencilssfg/composer/class_composer.py | 14 +++++++++++--- src/pystencilssfg/ir/postprocessing.py | 2 +- tests/ir/test_postprocessing.py | 6 +++--- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index bd90678..4184297 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -72,16 +72,24 @@ class SfgClassComposer(SfgComposerMixIn): """ def __init__(self, *params: VarLike): - self._params = tuple(asvar(p) for p in params) + self._params = list(asvar(p) for p in params) self._initializers: list[str] = [] self._body: str | None = None - def init(self, var: VarLike): + def add_param(self, param: VarLike, at: int | None = None): + if at is None: + self._params.append(asvar(param)) + else: + self._params.insert(at, asvar(param)) + + def init(self, var: VarLike | str): """Add an initialization expression to the constructor's initializer list.""" + member = var if isinstance(var, str) else asvar(var) + def init_sequencer(*args: ExprLike): expr = ", ".join(str(arg) for arg in args) - initializer = f"{asvar(var)}{{ {expr} }}" + initializer = f"{member}{{ {expr} }}" self._initializers.append(initializer) return self diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index fa3530b..e073356 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -301,7 +301,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): else: return expr - def get_shape(coord, symb: SfgKernelParamVar | int): + def get_shape(coord, symb: SfgKernelParamVar | str): expr = self._extraction.size(coord) if expr is None: diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 3030e12..6a38d91 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -109,7 +109,7 @@ def test_field_extraction(): khandle = sfg.kernels.create(set_constant) extraction = TestFieldExtraction("f") - call_tree = make_sequence(sfg.map_field(f, extraction), sfg.call(khandle)) + call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)) pp = CallTreePostProcessing() free_vars = pp.get_live_variables(call_tree) @@ -143,8 +143,8 @@ def test_duplicate_field_shapes(): khandle = sfg.kernels.create(set_constant) call_tree = make_sequence( - sfg.map_field(g, TestFieldExtraction("g")), - sfg.map_field(f, TestFieldExtraction("f")), + sfg.map_field(g, TestFieldExtraction("g"), cast_indexing_symbols=False), + sfg.map_field(f, TestFieldExtraction("f"), cast_indexing_symbols=False), sfg.call(khandle), ) -- GitLab