Skip to content
Snippets Groups Projects
Commit 63a9140d authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes in assignments collection simplifications / topological sort

parent 0fd71fbf
No related branches found
No related tags found
No related merge requests found
Pipeline #17176 passed
...@@ -197,14 +197,16 @@ class AssignmentCollection: ...@@ -197,14 +197,16 @@ class AssignmentCollection:
return res return res
def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False, def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
substitute_on_lhs: bool = True) -> 'AssignmentCollection': substitute_on_lhs: bool = True,
sort_topologically: bool = True) -> 'AssignmentCollection':
"""Returns new object, where terms are substituted according to the passed substitution dict. """Returns new object, where terms are substituted according to the passed substitution dict.
Args: Args:
substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
sort_topologically: if subexpressions are added as substitutions and this parameters is true,
the subexpressions are sorted topologically after insertion
Returns: Returns:
New AssignmentCollection where substitutions have been applied, self is not altered. New AssignmentCollection where substitutions have been applied, self is not altered.
""" """
...@@ -215,7 +217,8 @@ class AssignmentCollection: ...@@ -215,7 +217,8 @@ class AssignmentCollection:
if add_substitutions_as_subexpressions: if add_substitutions_as_subexpressions:
transformed_subexpressions = [Assignment(b, a) for a, b in transformed_subexpressions = [Assignment(b, a) for a, b in
substitutions.items()] + transformed_subexpressions substitutions.items()] + transformed_subexpressions
transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions) if sort_topologically:
transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
return self.copy(transformed_assignments, transformed_subexpressions) return self.copy(transformed_assignments, transformed_subexpressions)
def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection': def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
......
...@@ -104,7 +104,7 @@ def add_subexpressions_for_divisions(ac): ...@@ -104,7 +104,7 @@ def add_subexpressions_for_divisions(ac):
divisors = sorted(list(divisors), key=lambda x: str(x)) divisors = sorted(list(divisors), key=lambda x: str(x))
new_symbol_gen = ac.subexpression_symbol_generator new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)} substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
return ac.new_with_substitutions(substitutions, True) return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
def add_subexpressions_for_sums(ac): def add_subexpressions_for_sums(ac):
...@@ -142,14 +142,18 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments ...@@ -142,14 +142,18 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
then the new values are computed and written to the same field in-place. then the new values are computed and written to the same field in-place.
""" """
field_reads = set() field_reads = set()
to_iterate = []
if subexpressions: if subexpressions:
for assignment in ac.subexpressions: to_iterate = chain(to_iterate, ac.subexpressions)
field_reads.update(assignment.rhs.atoms(Field.Access))
if main_assignments: if main_assignments:
for assignment in ac.main_assignments: to_iterate = chain(to_iterate, ac.main_assignments)
for assignment in to_iterate:
if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
field_reads.update(assignment.rhs.atoms(Field.Access)) field_reads.update(assignment.rhs.atoms(Field.Access))
substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads} substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False) return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
substitute_on_lhs=False, sort_topologically=False)
def transform_rhs(assignment_list, transformation, *args, **kwargs): def transform_rhs(assignment_list, transformation, *args, **kwargs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment