From 36112534ddf534c7b75bd9c738017b2adb7f4cd2 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Thu, 20 Jul 2023 13:35:53 +0200
Subject: [PATCH] Fix symbol counters

---
 pystencils/simp/assignment_collection.py | 9 ++++++---
 pystencils/typing/types.py               | 4 ++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index 49fc06e2d..b0c09cec9 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -61,8 +61,11 @@ class AssignmentCollection:
 
         self.simplification_hints = simplification_hints
 
+        ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name]
+        max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
+
         if subexpression_symbol_generator is None:
-            self.subexpression_symbol_generator = SymbolGen()
+            self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
         else:
             self.subexpression_symbol_generator = subexpression_symbol_generator
 
@@ -453,8 +456,8 @@ class AssignmentCollection:
 class SymbolGen:
     """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
 
-    def __init__(self, symbol="xi", dtype=None):
-        self._ctr = 0
+    def __init__(self, symbol="xi", dtype=None, ctr=0):
+        self._ctr = ctr
         self._symbol = symbol
         self._dtype = dtype
 
diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py
index 531a8e290..f0f9744a5 100644
--- a/pystencils/typing/types.py
+++ b/pystencils/typing/types.py
@@ -70,7 +70,7 @@ class BasicType(AbstractType):
     BasicType is defined with a const qualifier and a np.dtype.
     """
 
-    def __init__(self, dtype: Union[np.dtype, 'BasicType', str], const: bool = False):
+    def __init__(self, dtype: Union[type, 'BasicType', str], const: bool = False):
         if isinstance(dtype, BasicType):
             self.numpy_dtype = dtype.numpy_dtype
             self.const = dtype.const
@@ -291,7 +291,7 @@ class StructType(AbstractType):
         return hash((self.numpy_dtype, self.const))
 
 
-def create_type(specification: Union[np.dtype, AbstractType, str]) -> AbstractType:
+def create_type(specification: Union[type, AbstractType, str]) -> AbstractType:
     # TODO: Deprecated Use the constructor of BasicType or StructType instead
     """Creates a subclass of Type according to a string or an object of subclass Type.
 
-- 
GitLab