Skip to content
Snippets Groups Projects
Select Git revision
  • d58ba8df51f66622e477cde069e3b232e606f65f
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

test_simplifications.py

Blame
  • test_simplifications.py 6.65 KiB
    from sys import version_info as vs
    import pytest
    
    import pystencils.config
    import sympy as sp
    import pystencils as ps
    
    from pystencils.simp import subexpression_substitution_in_main_assignments
    from pystencils.simp import add_subexpressions_for_divisions
    from pystencils.simp import add_subexpressions_for_sums
    from pystencils.simp import add_subexpressions_for_field_reads
    from pystencils.simp.simplifications import add_subexpressions_for_constants
    from pystencils import Assignment, AssignmentCollection, fields
    
    a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
    s0, s1, s2, s3 = sp.symbols("s_:4")
    f = sp.symbols("f_:9")
    
    
    def test_subexpression_substitution_in_main_assignments():
        subexpressions = [
            Assignment(s0, 2 * a + 2 * b),
            Assignment(s1, 2 * a + 2 * b + 2 * c),
            Assignment(s2, 2 * a + 2 * b + 2 * c + 2 * d),
            Assignment(s3, 2 * a + 2 * b * c),
            Assignment(x, s1 + s2 + s0 + s3)
        ]
        main = [
            Assignment(f[0], s1 + s2 + s0 + s3),
            Assignment(f[1], s1 + s2 + s0 + s3),
            Assignment(f[2], s1 + s2 + s0 + s3),
            Assignment(f[3], s1 + s2 + s0 + s3),
            Assignment(f[4], s1 + s2 + s0 + s3)
        ]
        ac = AssignmentCollection(main, subexpressions)
        ac = subexpression_substitution_in_main_assignments(ac)
        for i in range(0, len(ac.main_assignments)):
            assert ac.main_assignments[i].rhs == x
    
    
    def test_add_subexpressions_for_divisions():
        subexpressions = [
            Assignment(s0, 2 / a + 2 / b),
            Assignment(s1, 2 / a + 2 / b + 2 / c),
            Assignment(s2, 2 / a + 2 / b + 2 / c + 2 / d),
            Assignment(s3, 2 / a + 2 / b / c),
            Assignment(x, s1 + s2 + s0 + s3)
        ]
        main = [
            Assignment(f[0], s1 + s2 + s0 + s3)
        ]
        ac = AssignmentCollection(main, subexpressions)
        divs_before_optimisation = ac.operation_count["divs"]
        ac = add_subexpressions_for_divisions(ac)
        divs_after_optimisation = ac.operation_count["divs"]
        assert divs_before_optimisation - divs_after_optimisation == 8
        rhs = []
        for i in range(len(ac.subexpressions)):
            rhs.append(ac.subexpressions[i].rhs)
    
        assert 1/a in rhs
        assert 1/b in rhs
        assert 1/c in rhs
        assert 1/d in rhs
    
    
    def test_add_subexpressions_for_constants():
        half = sp.Rational(1,2)
        sqrt_2 = sp.sqrt(2)
        main = [
            Assignment(f[0], half * a + half * b + half * c),
            Assignment(f[1], - half * a - half * b),
            Assignment(f[2], a * sqrt_2 - b * sqrt_2),
            Assignment(f[3], a**2 + b**2)
        ]
        ac = AssignmentCollection(main)
        ac = add_subexpressions_for_constants(ac)
        
        assert len(ac.subexpressions) == 2
        
        half_subexp = None
        sqrt_subexp = None
    
        for asm in ac.subexpressions:
            if asm.rhs == half:
                half_subexp = asm.lhs
            elif asm.rhs == sqrt_2:
                sqrt_subexp = asm.lhs
            else:
                pytest.fail(f"An unexpected subexpression was encountered: {asm}")
                
        assert half_subexp is not None
        assert sqrt_subexp is not None
        
        for asm in ac.main_assignments[:3]:
            assert isinstance(asm.rhs, sp.Mul)
    
        assert any(arg == half_subexp for arg in ac.main_assignments[0].rhs.args)
        assert any(arg == half_subexp for arg in ac.main_assignments[1].rhs.args)
        assert any(arg == sqrt_subexp for arg in ac.main_assignments[2].rhs.args)
    
        #   Do not replace exponents!
        assert ac.main_assignments[3].rhs == a**2 + b**2
    
    
    def test_add_subexpressions_for_sums():
        subexpressions = [
            Assignment(s0, a + b + c + d),
            Assignment(s1, 3 * a * sp.sqrt(x) + 4 * b + c),
            Assignment(s2, 3 * a * sp.sqrt(x) + 4 * b + c),
            Assignment(s3, 3 * a * sp.sqrt(x) + 4 * b + c)
        ]
        main = [
            Assignment(f[0], s1 + s2 + s0 + s3)
        ]
        ac = AssignmentCollection(main, subexpressions)
        ops_before_optimisation = ac.operation_count
        ac = add_subexpressions_for_sums(ac)
        ops_after_optimisation = ac.operation_count
        assert ops_after_optimisation["adds"] == ops_before_optimisation["adds"]
        assert ops_after_optimisation["muls"] < ops_before_optimisation["muls"]
        assert ops_after_optimisation["sqrts"] < ops_before_optimisation["sqrts"]
    
        rhs = []
        for i in range(len(ac.subexpressions)):
            rhs.append(ac.subexpressions[i].rhs)
    
        assert a + b + c + d in rhs
        assert 3 * a * sp.sqrt(x) in rhs
    
    
    def test_add_subexpressions_for_field_reads():
        s, v = fields("s(5), v(5): double[2D]")
        subexpressions = []
        main = [
            Assignment(s[0, 0](0), 3 * v[0, 0](0)),
            Assignment(s[0, 0](1), 10 * v[0, 0](1))
        ]
        ac = AssignmentCollection(main, subexpressions)
        assert len(ac.subexpressions) == 0
        ac = add_subexpressions_for_field_reads(ac)
        assert len(ac.subexpressions) == 2
    
    
    @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
    @pytest.mark.parametrize('simplification', (True, False))
    @pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
    def test_sympy_optimizations(target, simplification):
        if target == ps.Target.GPU:
            pytest.importorskip("pycuda")
        src, dst = ps.fields('src, dst:  float32[2d]')
    
        # Triggers Sympy's expm1 optimization
        # Sympy's expm1 optimization is tedious to use and the behaviour is highly depended on the sympy version. In
        # some cases the exp expression has to be encapsulated in brackets or multiplied with 1 or 1.0
        # for sympy to work properly ...
        assignments = ps.AssignmentCollection({
            src[0, 0]: 1.0 * (sp.exp(dst[0, 0]) - 1)
        })
    
        config = pystencils.config.CreateKernelConfig(target=target, default_assignment_simplifications=simplification)
        ast = ps.create_kernel(assignments, config=config)
    
        code = ps.get_code_str(ast)
        if simplification:
            assert 'expm1(' in code
        else:
            assert 'expm1(' not in code
    
    
    @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
    @pytest.mark.parametrize('simplification', (True, False))
    @pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
    def test_evaluate_constant_terms(target, simplification):
        if target == ps.Target.GPU:
            pytest.importorskip("pycuda")
        src, dst = ps.fields('src, dst:  float32[2d]')
    
        # Triggers Sympy's cos optimization
        assignments = ps.AssignmentCollection({
            src[0, 0]: -sp.cos(1) + dst[0, 0]
        })
    
        config = pystencils.config.CreateKernelConfig(target=target, default_assignment_simplifications=simplification)
        ast = ps.create_kernel(assignments, config=config)
        code = ps.get_code_str(ast)
        if simplification:
            assert 'cos(' not in code
        else:
            assert 'cos(' in code
        print(code)