Skip to content
Snippets Groups Projects
Select Git revision
  • 6bc218dfbdb3f9d72ddb1c863dd3d109e7d827db
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

test_sympyextensions.py

Blame
  • Frederik Hennig's avatar
    Frederik Hennig authored
      - remove DivFunc
      - Fix various imports
      - Move node_collection, isl, kernel_constraints_check to old
    6bc218df
    History
    test_sympyextensions.py 7.16 KiB
    import sympy
    import numpy as np
    import sympy as sp
    import pystencils
    
    from pystencils.sympyextensions import replace_second_order_products
    from pystencils.sympyextensions import remove_higher_order_terms
    from pystencils.sympyextensions import complete_the_squares_in_exp
    from pystencils.sympyextensions import extract_most_common_factor
    from pystencils.sympyextensions import simplify_by_equality
    from pystencils.sympyextensions import count_operations
    from pystencils.sympyextensions import common_denominator
    from pystencils.sympyextensions import get_symmetric_part
    from pystencils.sympyextensions import scalar_product
    from pystencils.sympyextensions import kronecker_delta
    
    from pystencils import Assignment
    from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt,
                                               insert_fast_divisions, insert_fast_sqrts)
    
    
    def test_utility():
        """Exercise scalar_product (compared against numpy.dot) and kronecker_delta."""
        left = [1, 2]
        right = (2, 3)
        expected_dot = np.dot(np.array(left), np.array(right))
        assert scalar_product(left, right) == expected_dot

        sym_a = sympy.Symbol("a")
        sym_b = sympy.Symbol("b")

        # Symbolic arguments: one differing argument gives 0, all identical gives 1.
        assert kronecker_delta(sym_a, sym_a, sym_a, sym_b) == 0
        assert kronecker_delta(sym_a, sym_a, sym_a, sym_a) == 1

        # Integer arguments follow the same pattern.
        assert kronecker_delta(3, 3, 3, 2) == 0
        assert kronecker_delta(2, 2, 2, 2) == 1

        # A single list argument and tuple-valued arguments are accepted as well.
        assert kronecker_delta([10] * 100) == 1
        assert kronecker_delta((0, 1), (0, 1)) == 1
    
    
    def test_replace_second_order_products():
        """Check that the mixed product 4*x*y is rewritten in terms of binomial squares."""
        x, y = sympy.symbols('x y')
        product = 4 * x * y
        via_sum = 2 * ((x + y) ** 2 - x ** 2 - y ** 2)
        via_difference = 2 * (-(x - y) ** 2 + x ** 2 + y ** 2)

        # positive=True is expected to use (x + y)**2 and positive=False (x - y)**2.
        # The result must match structurally and also be mathematically equal.
        for positive, expected in ((True, via_sum), (False, via_difference)):
            rewritten = replace_second_order_products(product, search_symbols=[x, y], positive=positive)
            assert rewritten == expected
            assert (rewritten - expected).simplify() == 0

        # positive=None produces the same form as positive=True for this input.
        rewritten = replace_second_order_products(product, search_symbols=[x, y], positive=None)
        assert rewritten == via_sum

        # Passing a list via replace_mixed grows that list by one entry here.
        mixed_assignments = [Assignment(sympy.symbols('z'), x + y)]
        replace_second_order_products(product, search_symbols=[x, y], positive=True,
                                      replace_mixed=mixed_assignments)
        assert len(mixed_assignments) == 2

        # An expression without a mixed x*y product comes back unchanged.
        assert replace_second_order_products(4 + y, search_symbols=[x, y]) == y + 4
    
    
    def test_remove_higher_order_terms():
        """Terms whose total degree exceeds ``order`` are dropped; others are kept."""
        x, y = sympy.symbols('x y')
        product = sympy.Mul(x, y)
        cube = sympy.Pow(x, 3)

        # (term, order, expected result)
        cases = [
            (product, 1, 0),
            (product, 2, product),
            (cube, 2, 0),
            (cube, 3, cube),
        ]
        for term, order, expected in cases:
            assert remove_higher_order_terms(term, order=order, symbols=[x, y]) == expected
    
    
    def test_complete_the_squares_in_exp():
        """complete_the_squares_in_exp only rewrites the arguments of exp()."""
        a, b, c, s, _ = sympy.symbols('a b c s n')
        quadratic = a * s ** 2 + b * s + c

        # Without an exp() the expression must come back unchanged.
        assert complete_the_squares_in_exp(quadratic, symbols_to_complete=[s]) == quadratic

        # Inside exp(), the linear term b*s disappears and the constant
        # offset -b**2 / (4*a) from completing the square remains.
        expected = sympy.exp(a*s**2 + c - b**2 / (4*a))
        assert complete_the_squares_in_exp(sympy.exp(quadratic), symbols_to_complete=[s]) == expected
    
    
    def test_extract_most_common_factor():
        """The extracted (factor, rest) pair must multiply back to the input expression."""
        x, y = sympy.symbols('x y')

        def split_and_check(expr, expected_factor):
            # Shared checks: the expected leading factor, and the product round-trips.
            parts = extract_most_common_factor(expr)
            assert parts[0] == expected_factor
            assert sympy.prod(parts) == expr
            return parts

        # SymPy already collapses this sum into 7 / (x + y) before the call.
        split_and_check(1 / (x + y) + 3 / (x + y) + 3 / (x + y), 7)
        # Expected factor is 3, the coefficient shared by two of the three terms.
        split_and_check(1 / x + 3 / (x + y) + 3 / y, 3)
        # A lone term yields factor 1, and the remainder is the term itself.
        parts = split_and_check(1 / x, 1)
        assert parts[1] == 1 / x
    
    
    def test_count_operations():
        """Check the per-category operation counts reported by count_operations."""
        x, y, z = sympy.symbols('x y z')

        def expect_counts(expected, *terms):
            # Count with only_type=None and compare each listed category.
            counts = count_operations(*terms, only_type=None)
            for category, amount in expected.items():
                assert counts[category] == amount

        expect_counts({'adds': 1, 'muls': 1, 'divs': 1, 'sqrts': 1}, 1/x + y * sympy.sqrt(z))
        expect_counts({'adds': 0, 'muls': 0, 'divs': 1, 'sqrts': 1}, 1 / sympy.sqrt(z))
        expect_counts({'adds': 0, 'muls': 0, 'divs': 1, 'sqrts': 1}, sympy.Rel(1 / sympy.sqrt(z), 5))

        # Fast-approximation nodes are collected with .atoms() and passed in directly.
        fast_sqrt_nodes = insert_fast_sqrts(sympy.sqrt(x + y)).atoms(fast_sqrt)
        expect_counts({'fast_sqrts': 1}, *fast_sqrt_nodes)

        fast_div_nodes = insert_fast_divisions(sympy.sqrt(x / y)).atoms(fast_division)
        expect_counts({'fast_div': 1}, *fast_div_nodes)

        tmp_assignment = pystencils.Assignment(sympy.Symbol('tmp'), 3 / sympy.sqrt(x + y))
        fast_inv_sqrt_nodes = insert_fast_sqrts(tmp_assignment).atoms(fast_inv_sqrt)
        expect_counts({'fast_inv_sqrts': 1}, *fast_inv_sqrt_nodes)

        expect_counts({'adds': 1}, sympy.Piecewise((1.0, x > 0), (0.0, True)) + y * z)
        expect_counts({'adds': 1, 'muls': 99, 'divs': 1, 'sqrts': 1},
                      sympy.Pow(1/x + y * sympy.sqrt(z), 100))
        expect_counts({'divs': 1}, x / y)
        expect_counts({'adds': 2, 'divs': 1}, x + z / y + z)

        # An unevaluated product of 100 copies of x is expected to report 99 muls.
        hundred_fold_product = sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))
        expect_counts({'muls': 99}, hundred_fold_product)
        expect_counts({'divs': 1, 'muls': 99}, 1 / hundred_fold_product)
        expect_counts({'adds': 1, 'divs': 1, 'muls': 99}, (y + z) / hundred_fold_product)
    
    
    def test_common_denominator():
        """1/2 + 2*x/3 has the common denominator 6."""
        x = sympy.symbols('x')
        assert common_denominator(sympy.Rational(1, 2) + x * sympy.Rational(2, 3)) == 6
    
    
    def test_get_symmetric_part():
        """Check get_symmetric_part with respect to the symbols y and z.

        In this example only ``z / 3`` (odd in z) is removed; the constant-in-(y, z)
        term ``x / 9`` and the even terms ``y**2 / 6`` and ``z**2 / 3`` are kept.
        """
        x, y, z = sympy.symbols('x y z')
        expr = x / 9 - y ** 2 / 6 + z ** 2 / 3 + z / 3
        expected_result = x / 9 - y ** 2 / 6 + z ** 2 / 3
        # Plain string literal: the previous f-string had no placeholders (F541).
        sym_part = get_symmetric_part(expr, sympy.symbols('y z'))

        assert sym_part == expected_result
    
    
    def test_simplify_by_equality():
        """simplify_by_equality(expr, a, b, c) is exercised under the relation a = b + c."""
        x, y, z = sp.symbols('x, y, z')
        p, q = sp.symbols('p, q')

        # Each case: (input expression, (a, b, c) arguments, expected result).
        cases = [
            # Let x = y + z
            (x * p - y * p + z * q, (x, y, z), z * p + z * q),
            (x * (p - 2 * q) + 2 * q * z, (x, y, z), x * p - 2 * q * y),
            (x * (y + z) - y * z, (x, y, z), x*y + z**2),
            # Let x = y + 2
            (x * p - 2 * p, (x, y, 2), y * p),
        ]
        for source_expr, relation, expected in cases:
            assert simplify_by_equality(source_expr, *relation) == expected