Skip to content
Snippets Groups Projects
Select Git revision
  • 447111eb8e38c0bf15fec66f78e302fc77d6d424
  • master default protected
  • suffa/cumulantfourth_order_correction_with_psm
  • mr_refactor_wfb
  • Sparse
  • WallLaw
  • improved_comm
  • release/1.3.7
  • release/1.3.6
  • release/1.3.5
  • release/1.3.4
  • release/1.3.3
  • release/1.3.2
  • release/1.3.1
  • release/1.3
  • release/1.2
  • release/1.1.1
  • release/1.1
  • release/1.0.1
  • release/1.0
  • release/0.4.4
  • release/0.4.3
  • release/0.4.2
  • release/0.4.1
  • release/0.4.0
  • release/0.3.4
  • release/0.3.3
27 results

__init__.py

Blame
  • test_subexpression_insertion.py 1.50 KiB
    from pystencils import fields, Assignment, AssignmentCollection
    from pystencils.simp.subexpression_insertion import *
    
    
    def test_subexpression_insertion():
        """Check which subexpressions each ``insert_*`` pass inlines.

        Ten subexpressions ``xi_0 .. xi_9`` are defined over field ``f``, and each
        one is assigned to a component of field ``g``. Every insertion pass is
        applied to the same, original collection ``ac``. Afterwards, the set of
        ``xi`` symbols still bound in the result must equal the full set minus
        exactly the symbols that pass is expected to remove.
        """
        # Import sympy explicitly instead of relying on ``sp`` leaking through the
        # wildcard import of ``pystencils.simp.subexpression_insertion``.
        import sympy as sp

        f, g = fields('f(10), g(10) : [2D]')
        xi = sp.symbols('xi_:10')
        xi_set = set(xi)

        subexpressions = [
            Assignment(xi[0], -f(4)),
            Assignment(xi[1], -(f(1) * f(2))),
            Assignment(xi[2], 2.31 * f(5)),
            Assignment(xi[3], 1.8 + f(5) + f(6)),
            Assignment(xi[4], 5.7 + f(6)),
            Assignment(xi[5], (f(4) + f(5))**2),
            Assignment(xi[6], f(3)**2),
            Assignment(xi[7], f(4)),
            Assignment(xi[8], 13),
            Assignment(xi[9], 0),
        ]

        assignments = [Assignment(g(i), x) for i, x in enumerate(xi)]
        ac = AssignmentCollection(assignments, subexpressions=subexpressions)

        def remaining_xi(collection):
            """Return the ``xi`` symbols that are still bound in *collection*."""
            return collection.bound_symbols & xi_set

        # xi_0 = -f(4) is removed; xi_1 = -(f(1) * f(2)) is expected to stay.
        assert remaining_xi(insert_symbol_times_minus_one(ac)) == xi_set - {xi[0]}

        # xi_0 (-1 * f(4)) and xi_2 (2.31 * f(5)) are both expected to be removed.
        assert remaining_xi(insert_constant_multiples(ac)) == xi_set - {xi[0], xi[2]}

        # xi_4 = 5.7 + f(6) is removed; xi_3 = 1.8 + f(5) + f(6) is expected to stay.
        assert remaining_xi(insert_constant_additions(ac)) == xi_set - {xi[4]}

        # xi_6 = f(3)**2 is removed; xi_5 = (f(4) + f(5))**2 is expected to stay.
        assert remaining_xi(insert_squares(ac)) == xi_set - {xi[6]}

        # xi_7 = f(4) is the only subexpression expected to be removed here.
        assert remaining_xi(insert_aliases(ac)) == xi_set - {xi[7]}

        # xi_9 = 0 is the only subexpression expected to be removed here.
        assert remaining_xi(insert_zeros(ac)) == xi_set - {xi[9]}

        # With xi_9 (= 0) skipped, only xi_8 = 13 is expected to be removed.
        assert remaining_xi(insert_constants(ac, skip={xi[9]})) == xi_set - {xi[8]}