Skip to content
Snippets Groups Projects

Fix AssignmentCollection.{free_symbols,bound_symbols,defined_symbols} for non-Assignments

3 files
+ 47
7
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -3,6 +3,7 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set,
@@ -3,6 +3,7 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set,
import sympy as sp
import sympy as sp
 
import pystencils
from pystencils.assignment import Assignment
from pystencils.assignment import Assignment
from pystencils.simp.simplifications import (
from pystencils.simp.simplifications import (
sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
@@ -100,15 +101,29 @@ class AssignmentCollection:
@@ -100,15 +101,29 @@ class AssignmentCollection:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
free_symbols = set()
free_symbols = set()
for eq in self.all_assignments:
for eq in self.all_assignments:
free_symbols.update(eq.rhs.atoms(sp.Symbol))
if isinstance(eq, Assignment):
 
free_symbols.update(eq.rhs.atoms(sp.Symbol))
 
elif isinstance(eq, pystencils.astnodes.Node):
 
free_symbols.update(eq.undefined_symbols)
 
return free_symbols - self.bound_symbols
return free_symbols - self.bound_symbols
@property
@property
def bound_symbols(self) -> Set[sp.Symbol]:
def bound_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur on the left hand side of a main assignment or a subexpression."""
"""All symbols which occur on the left hand side of a main assignment or a subexpression."""
bound_symbols_set = set([eq.lhs for eq in self.all_assignments])
bound_symbols_set = set(
assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \
[assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
 
)
 
 
assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
"Not in SSA form - same symbol assigned multiple times"
"Not in SSA form - same symbol assigned multiple times"
 
 
bound_symbols_set = bound_symbols_set.union(*[
 
assignment.symbols_defined for assignment in self.all_assignments
 
if isinstance(assignment, pystencils.astnodes.Node)
 
]
 
)
 
return bound_symbols_set
return bound_symbols_set
@property
@property
@@ -124,7 +139,11 @@ class AssignmentCollection:
@@ -124,7 +139,11 @@ class AssignmentCollection:
@property
@property
def defined_symbols(self) -> Set[sp.Symbol]:
def defined_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur as left-hand-sides of one of the main equations"""
"""All symbols which occur as left-hand-sides of one of the main equations"""
return set([assignment.lhs for assignment in self.main_assignments])
return (set(
 
[assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)]
 
).union(*[assignment.symbols_defined for assignment in self.main_assignments if isinstance(
 
assignment, pystencils.astnodes.Node)]
 
))
@property
@property
def operation_count(self):
def operation_count(self):
Loading