Skip to content
Snippets Groups Projects
Select Git revision
  • afdc27bdb5af8e5f1fe9c274259f743235e43c1f
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

sympyextensions.rst

Blame
  • test_subexpression_insertion.py 1.50 KiB
    from pystencils import fields, Assignment, AssignmentCollection
    from pystencils.simp.subexpression_insertion import *
    
    
    def test_subexpression_insertion():
        """Each insertion pass must inline exactly the subexpressions of the form it targets.

        Ten subexpressions ``xi_0`` ... ``xi_9`` of different shapes (negation,
        constant multiple, constant addition, square, alias, zero, constant) are
        built. Every insertion pass is applied to the original collection ``ac``,
        and the test checks that precisely the expected ``xi`` symbols are no
        longer among ``bound_symbols`` afterwards.
        """
        # Import sympy explicitly: previously ``sp`` was only reachable because the
        # star import above happened to re-export it, which breaks silently if that
        # module ever restricts its exports (e.g. via ``__all__``).
        import sympy as sp

        f, g = fields('f(10), g(10) : [2D]')
        xi = sp.symbols('xi_:10')
        xi_set = set(xi)

        subexpressions = [
            Assignment(xi[0], -f(4)),
            Assignment(xi[1], -(f(1) * f(2))),
            Assignment(xi[2], 2.31 * f(5)),
            Assignment(xi[3], 1.8 + f(5) + f(6)),
            Assignment(xi[4], 5.7 + f(6)),
            Assignment(xi[5], (f(4) + f(5))**2),
            Assignment(xi[6], f(3)**2),
            Assignment(xi[7], f(4)),
            Assignment(xi[8], 13),
            Assignment(xi[9], 0),
        ]

        assignments = [Assignment(g(i), x) for i, x in enumerate(xi)]
        ac = AssignmentCollection(assignments, subexpressions=subexpressions)

        # (label, insertion pass, symbols that pass is expected to inline away).
        # xi_1, xi_3 and xi_5 are near-misses that no pass should touch.
        cases = [
            ('insert_symbol_times_minus_one', insert_symbol_times_minus_one, {xi[0]}),
            ('insert_constant_multiples', insert_constant_multiples, {xi[0], xi[2]}),
            ('insert_constant_additions', insert_constant_additions, {xi[4]}),
            ('insert_squares', insert_squares, {xi[6]}),
            ('insert_aliases', insert_aliases, {xi[7]}),
            ('insert_zeros', insert_zeros, {xi[9]}),
            # xi_9 == 0 is itself a constant; skip it so only xi_8 is expected.
            ('insert_constants(skip={xi_9})',
             lambda coll: insert_constants(coll, skip={xi[9]}),
             {xi[8]}),
        ]

        for label, insertion_pass, inlined in cases:
            ac_ins = insertion_pass(ac)
            remaining = ac_ins.bound_symbols & xi_set
            assert remaining == (xi_set - inlined), (label, remaining)