diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 9d253ff7b01dcdbeb2c3591de4d6eeb0e495cc74..e0f5ec926376205f7f5ed68650791e75b1b634da 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -5,7 +5,7 @@ import sympy as sp from pystencils.assignment import Assignment from pystencils.simp.simplifications import ( - sort_assignments_topologically, sympy_cse_on_assignment_list, + sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) from pystencils.sympyextensions import count_operations, fast_subs @@ -85,9 +85,9 @@ class AssignmentCollection: def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None: """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition.""" if sort_subexpressions: - self.subexpressions = sympy_cse_on_assignment_list(self.subexpressions) + self.subexpressions = sort_assignments_topologically(self.subexpressions) if sort_main_assignments: - self.main_assignments = sympy_cse_on_assignment_list(self.main_assignments) + self.main_assignments = sort_assignments_topologically(self.main_assignments) # ---------------------------------------------- Properties ------------------------------------------------------- diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py index ab2b3d83df89a4bf3f4f3fe5af03bcc744fae1d7..3a4c64764e3c17d46c204e6a39abeab5b3d13439 100644 --- a/pystencils/simp/simplifications.py +++ b/pystencils/simp/simplifications.py @@ -13,12 +13,13 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node] """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs""" edges = [] for c1, e1 in enumerate(assignments): - if isinstance(e1, Assignment): + if hasattr(e1, 'lhs') and hasattr(e1, 'rhs'): symbols = [e1.lhs] elif isinstance(e1, Node): symbols = e1.symbols_defined else: - symbols = [] + raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.") + for lhs in symbols: for c2, e2 in enumerate(assignments): if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols: @@ -155,14 +156,14 @@ def transform_rhs(assignment_list, transformation, *args, **kwargs): """Applies a transformation function on the rhs of each element of the passed assignment list If the list also contains other object, like AST nodes, these are ignored. Additional parameters are passed to the transformation function""" - return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a + return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a for a in assignment_list] def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs): return [Assignment(transformation(a.lhs, *args, **kwargs), transformation(a.rhs, *args, **kwargs)) - if isinstance(a, Assignment) else a + if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a for a in assignment_list]