Skip to content
Snippets Groups Projects
Commit a384e104 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Moved subexpression insertion to pystencils

parent 2464ef8e
No related branches found
No related tags found
No related merge requests found
...@@ -5,10 +5,17 @@ from .simplifications import ( ...@@ -5,10 +5,17 @@ from .simplifications import (
add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments, add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
subexpression_substitution_in_existing_subexpressions, subexpression_substitution_in_existing_subexpressions,
subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list) subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list)
from .subexpression_insertion import (
insert_aliases, insert_zeros, insert_constants,
insert_constant_additions, insert_constant_multiples,
insert_squares, insert_symbol_times_minus_one)
from .simplificationstrategy import SimplificationStrategy from .simplificationstrategy import SimplificationStrategy
__all__ = ['AssignmentCollection', 'SimplificationStrategy', __all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments', 'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions', 'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants', 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants',
'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads'] 'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads',
'insert_aliases', 'insert_zeros', 'insert_constants',
'insert_constant_additions', 'insert_constant_multiples',
'insert_squares', 'insert_symbol_times_minus_one']
import sympy as sp
from pystencils.sympyextensions import is_constant
# Subexpression Insertion
def insert_subexpressions(ac, selection_callback, skip=set()):
"""
Removes a number of subexpressions from an assignment collection by
inserting their right-hand side wherever they occur.
Args:
- selection_callback: Function that is called to qualify subexpressions
for insertion. Should return `True` for any subexpression that is to be
inserted, and `False` otherwise.
- skip: Set of symbols (left-hand sides of subexpressions) that should be
ignored even if qualified by the callback.
"""
i = 0
while i < len(ac.subexpressions):
exp = ac.subexpressions[i]
if exp.lhs not in skip and selection_callback(exp):
ac = ac.new_with_inserted_subexpression(exp.lhs)
else:
i += 1
return ac
def insert_aliases(ac, **kwargs):
"""Inserts subexpressions that are aliases of other symbols,
i.e. their right-hand side is only another symbol."""
return insert_subexpressions(ac, lambda x: isinstance(x.rhs, sp.Symbol), **kwargs)
def insert_zeros(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is zero."""
zero = sp.Integer(0)
return insert_subexpressions(ac, lambda x: x.rhs == zero, **kwargs)
def insert_constants(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is constant,
i.e. contains no symbols."""
return insert_subexpressions(ac, lambda x: is_constant(x.rhs), **kwargs)
def insert_symbol_times_minus_one(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is just a
negation of another symbol."""
def callback(exp):
rhs = exp.rhs
minus_one = sp.Integer(-1)
atoms = rhs.atoms(sp.Symbol)
return len(atoms) == 1 and rhs == minus_one * atoms.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_multiples(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a constant
multiplied with another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() * symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_additions(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a sum of a
constant and another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() + symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_squares(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is another symbol squared."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
return len(symbols) == 1 and rhs == symbols.pop() ** 2
return insert_subexpressions(ac, callback, **kwargs)
def bind_symbols_to_skip(insertion_function, skip):
return lambda ac: insertion_function(ac, skip=skip)
import sympy as sp
from pystencils import fields, Assignment, AssignmentCollection
from pystencils.simp.subexpression_insertion import *
def test_subexpression_insertion():
f, g = fields('f(10), g(10) : [2D]')
xi = sp.symbols('xi_:10')
xi_set = set(xi)
subexpressions = [
Assignment(xi[0], -f(4)),
Assignment(xi[1], -(f(1) * f(2))),
Assignment(xi[2], 2.31 * f(5)),
Assignment(xi[3], 1.8 + f(5) + f(6)),
Assignment(xi[4], 5.7 + f(6)),
Assignment(xi[5], (f(4) + f(5))**2),
Assignment(xi[6], f(3)**2),
Assignment(xi[7], f(4)),
Assignment(xi[8], 13),
Assignment(xi[9], 0),
]
assignments = [Assignment(g(i), x) for i, x in enumerate(xi)]
ac = AssignmentCollection(assignments, subexpressions=subexpressions)
ac_ins = insert_symbol_times_minus_one(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[0]})
ac_ins = insert_constant_multiples(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[0], xi[2]})
ac_ins = insert_constant_additions(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[4]})
ac_ins = insert_squares(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[6]})
ac_ins = insert_aliases(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[7]})
ac_ins = insert_zeros(ac)
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[9]})
ac_ins = insert_constants(ac, skip={xi[9]})
assert (ac_ins.bound_symbols & xi_set) == (xi_set - {xi[8]})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment