From 8d860b26e10431a995487265ca99bbc329bbf933 Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Thu, 16 Jan 2025 08:52:45 +0100
Subject: [PATCH] Expose method to retrieve the original name of a duplicated
 symbol

---
 .../backend/kernelcreation/context.py          | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index 39fb8ef6d..8f5931c64 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -126,7 +126,7 @@ class KernelCreationContext:
             symb.apply_dtype(dtype)
 
         return symb
-    
+
     def get_new_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol:
         """Always create a new symbol, deduplicating its name if another symbol with the same name already exists."""
 
@@ -173,15 +173,11 @@ class KernelCreationContext:
     ) -> PsSymbol:
         """Canonically duplicates the given symbol.
 
-        A new symbol with the new name ``symb.name + "__<counter>"`` and optionally a different data type 
+        A new symbol with the new name ``symb.name + "__<counter>"`` and optionally a different data type
         is created, added to the symbol table, and returned.
         The ``counter`` reflects the number of previously created duplicates of this symbol.
         """
-        if (result := self._symbol_ctr_pattern.search(symb.name)) is not None:
-            span = result.span()
-            basename = symb.name[: span[0]]
-        else:
-            basename = symb.name
+        basename = self.basename(symb)
 
         if new_dtype is None:
             new_dtype = symb.dtype
@@ -194,6 +190,14 @@ class KernelCreationContext:
                 return self.get_symbol(dup_name, new_dtype)
         assert False, "unreachable code"
 
+    def basename(self, symb: PsSymbol) -> str:
+        """Returns the original name a symbol had before duplication."""
+        if (result := self._symbol_ctr_pattern.search(symb.name)) is not None:
+            span = result.span()
+            return symb.name[: span[0]]
+        else:
+            return symb.name
+
     @property
     def symbols(self) -> Iterable[PsSymbol]:
         """Return an iterable of all symbols listed in the symbol table."""
-- 
GitLab