diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index bd906782d6e074955c60221d9a8f4d9b15bae772..4184297d6f3ce383d20c118de176fc11ce4b97a0 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -72,16 +72,24 @@ class SfgClassComposer(SfgComposerMixIn): """ def __init__(self, *params: VarLike): - self._params = tuple(asvar(p) for p in params) + self._params = list(asvar(p) for p in params) self._initializers: list[str] = [] self._body: str | None = None - def init(self, var: VarLike): + def add_param(self, param: VarLike, at: int | None = None): + if at is None: + self._params.append(asvar(param)) + else: + self._params.insert(at, asvar(param)) + + def init(self, var: VarLike | str): """Add an initialization expression to the constructor's initializer list.""" + member = var if isinstance(var, str) else asvar(var) + def init_sequencer(*args: ExprLike): expr = ", ".join(str(arg) for arg in args) - initializer = f"{asvar(var)}{{ {expr} }}" + initializer = f"{member}{{ {expr} }}" self._initializers.append(initializer) return self diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index fa3530bd70f4680b3f3a9bcaa4f95cb333a7f3a6..e0733569a3fa2342cecd114e30968f4c2944a84e 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -301,7 +301,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): else: return expr - def get_shape(coord, symb: SfgKernelParamVar | int): + def get_shape(coord, symb: SfgKernelParamVar | str): expr = self._extraction.size(coord) if expr is None: diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 3030e1294d784c88c11e5770c9af0bfe00c21333..6a38d91c5f21a074f09c0a51ca7db39c0e4cba5f 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -109,7 +109,7 @@ def test_field_extraction(): khandle = sfg.kernels.create(set_constant) extraction = TestFieldExtraction("f") - call_tree = make_sequence(sfg.map_field(f, extraction), sfg.call(khandle)) + call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)) pp = CallTreePostProcessing() free_vars = pp.get_live_variables(call_tree) @@ -143,8 +143,8 @@ def test_duplicate_field_shapes(): khandle = sfg.kernels.create(set_constant) call_tree = make_sequence( - sfg.map_field(g, TestFieldExtraction("g")), - sfg.map_field(f, TestFieldExtraction("f")), + sfg.map_field(g, TestFieldExtraction("g"), cast_indexing_symbols=False), + sfg.map_field(f, TestFieldExtraction("f"), cast_indexing_symbols=False), sfg.call(khandle), )