Skip to content
Snippets Groups Projects

Introduce default assignment simplifications

Merged Markus Holzer requested to merge holzer/pystencils:Simplifications into master
3 files
+ 60
7
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -3,13 +3,22 @@ from typing import Callable, List, Sequence, Union
@@ -3,13 +3,22 @@ from typing import Callable, List, Sequence, Union
from collections import defaultdict
from collections import defaultdict
import sympy as sp
import sympy as sp
 
from sympy.codegen.rewriting import optims_c99, optimize
 
from sympy.codegen.rewriting import ReplaceOptim
from pystencils.assignment import Assignment
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.astnodes import Node, SympyAssignment
from pystencils.field import AbstractField, Field
from pystencils.field import AbstractField, Field
from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
 
# Evaluates all constant terms
 
evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
 
lambda p: p.evalf())
 
 
sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
 
 
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
edges = []
edges = []
@@ -223,3 +232,16 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
@@ -223,3 +232,16 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
f.__name__ = operation.__name__
f.__name__ = operation.__name__
return f
return f
 
 
 
def apply_sympy_optimisations(assignments):
 
"""Applies default sympy optimisations. See sympy.codegen.rewriting"""
 
 
assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
 
if hasattr(a, 'lhs')
 
else a for a in assignments]
 
assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
 
for a in chain.from_iterable(assignments_nodes):
 
a.optimize(sympy_optimisations)
 
 
return assignments
Loading