Skip to content
Snippets Groups Projects

Expose method to retrieve the original name of a duplicated symbol

Merged Daniel Bauer requested to merge hyteg/pystencils:bauerd/basename into v2.0-dev
1 file
+ 11
7
Compare changes
  • Side-by-side
  • Inline
@@ -126,7 +126,7 @@ class KernelCreationContext:
@@ -126,7 +126,7 @@ class KernelCreationContext:
symb.apply_dtype(dtype)
symb.apply_dtype(dtype)
return symb
return symb
def get_new_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol:
def get_new_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol:
"""Always create a new symbol, deduplicating its name if another symbol with the same name already exists."""
"""Always create a new symbol, deduplicating its name if another symbol with the same name already exists."""
@@ -173,15 +173,11 @@ class KernelCreationContext:
@@ -173,15 +173,11 @@ class KernelCreationContext:
) -> PsSymbol:
) -> PsSymbol:
"""Canonically duplicates the given symbol.
"""Canonically duplicates the given symbol.
A new symbol with the new name ``symb.name + "__<counter>"`` and optionally a different data type
A new symbol with the new name ``symb.name + "__<counter>"`` and optionally a different data type
is created, added to the symbol table, and returned.
is created, added to the symbol table, and returned.
The ``counter`` reflects the number of previously created duplicates of this symbol.
The ``counter`` reflects the number of previously created duplicates of this symbol.
"""
"""
if (result := self._symbol_ctr_pattern.search(symb.name)) is not None:
basename = self.basename(symb)
span = result.span()
basename = symb.name[: span[0]]
else:
basename = symb.name
if new_dtype is None:
if new_dtype is None:
new_dtype = symb.dtype
new_dtype = symb.dtype
@@ -194,6 +190,14 @@ class KernelCreationContext:
@@ -194,6 +190,14 @@ class KernelCreationContext:
return self.get_symbol(dup_name, new_dtype)
return self.get_symbol(dup_name, new_dtype)
assert False, "unreachable code"
assert False, "unreachable code"
 
def basename(self, symb: PsSymbol) -> str:
 
"""Returns the original name a symbol had before duplication."""
 
if (result := self._symbol_ctr_pattern.search(symb.name)) is not None:
 
span = result.span()
 
return symb.name[: span[0]]
 
else:
 
return symb.name
 
@property
@property
def symbols(self) -> Iterable[PsSymbol]:
def symbols(self) -> Iterable[PsSymbol]:
"""Return an iterable of all symbols listed in the symbol table."""
"""Return an iterable of all symbols listed in the symbol table."""
Loading