Skip to content
Snippets Groups Projects
Commit 0fd71fbf authored by Martin Bauer's avatar Martin Bauer
Browse files

Fix bugs recently introduced in topological sort generalizations

parent bd49f37e
No related branches found
No related tags found
No related merge requests found
Pipeline #17098 passed
......@@ -5,7 +5,7 @@ import sympy as sp
from pystencils.assignment import Assignment
from pystencils.simp.simplifications import (
sort_assignments_topologically, sympy_cse_on_assignment_list,
sort_assignments_topologically,
transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs
......@@ -85,9 +85,9 @@ class AssignmentCollection:
def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
"""Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
if sort_subexpressions:
self.subexpressions = sympy_cse_on_assignment_list(self.subexpressions)
self.subexpressions = sort_assignments_topologically(self.subexpressions)
if sort_main_assignments:
self.main_assignments = sympy_cse_on_assignment_list(self.main_assignments)
self.main_assignments = sort_assignments_topologically(self.main_assignments)
# ---------------------------------------------- Properties -------------------------------------------------------
......
......@@ -13,12 +13,13 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
edges = []
for c1, e1 in enumerate(assignments):
if isinstance(e1, Assignment):
if hasattr(e1, 'lhs') and hasattr(e1, 'rhs'):
symbols = [e1.lhs]
elif isinstance(e1, Node):
symbols = e1.symbols_defined
else:
symbols = []
raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.")
for lhs in symbols:
for c2, e2 in enumerate(assignments):
if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
......@@ -155,14 +156,14 @@ def transform_rhs(assignment_list, transformation, *args, **kwargs):
"""Applies a transformation function on the rhs of each element of the passed assignment list
If the list also contains other object, like AST nodes, these are ignored.
Additional parameters are passed to the transformation function"""
return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a
return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
for a in assignment_list]
def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
return [Assignment(transformation(a.lhs, *args, **kwargs),
transformation(a.rhs, *args, **kwargs))
if isinstance(a, Assignment) else a
if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
for a in assignment_list]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment