Skip to content
Snippets Groups Projects

Introduce default assignment simplifications

Merged Markus Holzer requested to merge holzer/pystencils:Simplifications into master
Compare and
6 files
+ 155
104
Compare changes
  • Side-by-side
  • Inline
Files
6
@@ -3,13 +3,22 @@ from typing import Callable, List, Sequence, Union
from collections import defaultdict
import sympy as sp
from sympy.codegen.rewriting import optims_c99, optimize
from sympy.codegen.rewriting import ReplaceOptim
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.astnodes import Node, SympyAssignment
from pystencils.field import AbstractField, Field
from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
# Evaluates all constant terms
evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
lambda p: p.evalf())
sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
edges = []
@@ -223,3 +232,16 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
f.__name__ = operation.__name__
return f
def apply_sympy_optimisations(assignments):
"""Applies default sympy optimisations. See sympy.codegen.rewriting"""
assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
if hasattr(a, 'lhs')
else a for a in assignments]
assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
for a in chain.from_iterable(assignments_nodes):
a.optimize(sympy_optimisations)
return assignments
Loading