Skip to content
Snippets Groups Projects
Commit 0fd71fbf authored by Martin Bauer's avatar Martin Bauer
Browse files

Fix bugs recently introduced in topological sort generalizations

parent bd49f37e
Branches
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ import sympy as sp ...@@ -5,7 +5,7 @@ import sympy as sp
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.simp.simplifications import ( from pystencils.simp.simplifications import (
sort_assignments_topologically, sympy_cse_on_assignment_list, sort_assignments_topologically,
transform_lhs_and_rhs, transform_rhs) transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs from pystencils.sympyextensions import count_operations, fast_subs
...@@ -85,9 +85,9 @@ class AssignmentCollection: ...@@ -85,9 +85,9 @@ class AssignmentCollection:
def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None: def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
"""Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition.""" """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
if sort_subexpressions: if sort_subexpressions:
self.subexpressions = sympy_cse_on_assignment_list(self.subexpressions) self.subexpressions = sort_assignments_topologically(self.subexpressions)
if sort_main_assignments: if sort_main_assignments:
self.main_assignments = sympy_cse_on_assignment_list(self.main_assignments) self.main_assignments = sort_assignments_topologically(self.main_assignments)
# ---------------------------------------------- Properties ------------------------------------------------------- # ---------------------------------------------- Properties -------------------------------------------------------
......
...@@ -13,12 +13,13 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node] ...@@ -13,12 +13,13 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs""" """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
edges = [] edges = []
for c1, e1 in enumerate(assignments): for c1, e1 in enumerate(assignments):
if isinstance(e1, Assignment): if hasattr(e1, 'lhs') and hasattr(e1, 'rhs'):
symbols = [e1.lhs] symbols = [e1.lhs]
elif isinstance(e1, Node): elif isinstance(e1, Node):
symbols = e1.symbols_defined symbols = e1.symbols_defined
else: else:
symbols = [] raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.")
for lhs in symbols: for lhs in symbols:
for c2, e2 in enumerate(assignments): for c2, e2 in enumerate(assignments):
if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols: if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
...@@ -155,14 +156,14 @@ def transform_rhs(assignment_list, transformation, *args, **kwargs): ...@@ -155,14 +156,14 @@ def transform_rhs(assignment_list, transformation, *args, **kwargs):
"""Applies a transformation function on the rhs of each element of the passed assignment list """Applies a transformation function on the rhs of each element of the passed assignment list
If the list also contains other object, like AST nodes, these are ignored. If the list also contains other object, like AST nodes, these are ignored.
Additional parameters are passed to the transformation function""" Additional parameters are passed to the transformation function"""
return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if isinstance(a, Assignment) else a return [Assignment(a.lhs, transformation(a.rhs, *args, **kwargs)) if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
for a in assignment_list] for a in assignment_list]
def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs): def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
return [Assignment(transformation(a.lhs, *args, **kwargs), return [Assignment(transformation(a.lhs, *args, **kwargs),
transformation(a.rhs, *args, **kwargs)) transformation(a.rhs, *args, **kwargs))
if isinstance(a, Assignment) else a if hasattr(a, 'lhs') and hasattr(a, 'rhs') else a
for a in assignment_list] for a in assignment_list]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment