diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 39fb8ef6dac855553b7e18d2a688c67ca45fb227..8f5931c6494bbf1eca950e38df0053e79af3e81b 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -126,7 +126,7 @@ class KernelCreationContext: symb.apply_dtype(dtype) return symb - + def get_new_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol: """Always create a new symbol, deduplicating its name if another symbol with the same name already exists.""" @@ -173,15 +173,11 @@ class KernelCreationContext: ) -> PsSymbol: """Canonically duplicates the given symbol. - A new symbol with the new name ``symb.name + "__<counter>"`` and optionally a different data type + A new symbol with the new name ``symb.name + "__<counter>"`` and optionally a different data type is created, added to the symbol table, and returned. The ``counter`` reflects the number of previously created duplicates of this symbol. """ - if (result := self._symbol_ctr_pattern.search(symb.name)) is not None: - span = result.span() - basename = symb.name[: span[0]] - else: - basename = symb.name + basename = self.basename(symb) if new_dtype is None: new_dtype = symb.dtype @@ -194,6 +190,14 @@ class KernelCreationContext: return self.get_symbol(dup_name, new_dtype) assert False, "unreachable code" + def basename(self, symb: PsSymbol) -> str: + """Returns the original name a symbol had before duplication.""" + if (result := self._symbol_ctr_pattern.search(symb.name)) is not None: + span = result.span() + return symb.name[: span[0]] + else: + return symb.name + @property def symbols(self) -> Iterable[PsSymbol]: """Return an iterable of all symbols listed in the symbol table."""