From 8d860b26e10431a995487265ca99bbc329bbf933 Mon Sep 17 00:00:00 2001 From: Daniel Bauer <daniel.j.bauer@fau.de> Date: Thu, 16 Jan 2025 08:52:45 +0100 Subject: [PATCH] Expose method to retrieve the original name of a duplicated symbol --- .../backend/kernelcreation/context.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 39fb8ef6d..8f5931c64 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -126,7 +126,7 @@ class KernelCreationContext: symb.apply_dtype(dtype) return symb - + def get_new_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol: """Always create a new symbol, deduplicating its name if another symbol with the same name already exists.""" @@ -173,15 +173,11 @@ class KernelCreationContext: ) -> PsSymbol: """Canonically duplicates the given symbol. - A new symbol with the new name ``symb.name + "__<counter>"`` and optionally a different data type + A new symbol with the new name ``symb.name + "__<counter>"`` and optionally a different data type is created, added to the symbol table, and returned. The ``counter`` reflects the number of previously created duplicates of this symbol. """ - if (result := self._symbol_ctr_pattern.search(symb.name)) is not None: - span = result.span() - basename = symb.name[: span[0]] - else: - basename = symb.name + basename = self.basename(symb) if new_dtype is None: new_dtype = symb.dtype @@ -194,6 +190,14 @@ class KernelCreationContext: return self.get_symbol(dup_name, new_dtype) assert False, "unreachable code" + def basename(self, symb: PsSymbol) -> str: + """Returns the original name a symbol had before duplication.""" + if (result := self._symbol_ctr_pattern.search(symb.name)) is not None: + span = result.span() + return symb.name[: span[0]] + else: + return symb.name + @property def symbols(self) -> Iterable[PsSymbol]: """Return an iterable of all symbols listed in the symbol table.""" -- GitLab