Skip to content
Snippets Groups Projects

Introduce default assignment simplifications

Merged Markus Holzer requested to merge holzer/pystencils:Simplifications into master
1 file
+ 14
15
Compare changes
  • Side-by-side
  • Inline
@@ -5,20 +5,12 @@ import pystencils as ps
from pystencils import Assignment
from pystencils.astnodes import Block, LoopOverCoordinate, SkipIteration, SympyAssignment
sympy_numeric_version = [int(x, 10) for x in sp.__version__.split('.') if x.isdigit()]
if len(sympy_numeric_version) < 3:
sympy_numeric_version.append(0)
sympy_numeric_version.reverse()
sympy_version = sum(x * (100 ** i) for i, x in enumerate(sympy_numeric_version))
dst = ps.fields('dst(8): double[2D]')
s = sp.symbols('s_:8')
x = sp.symbols('x')
y = sp.symbols('y')
@pytest.mark.skipif(sympy_version < 10501,
reason="Old Sympy Versions behave differently which wont be supported in the near future")
def test_kernel_function():
assignments = [
Assignment(dst[0, 0](0), s[0]),
@@ -44,8 +36,6 @@ def test_skip_iteration():
assert skipped.undefined_symbols == set()
@pytest.mark.skipif(sympy_version < 10501,
reason="Old Sympy Versions behave differently which wont be supported in the near future")
def test_block():
assignments = [
Assignment(dst[0, 0](0), s[0]),
@@ -92,15 +82,24 @@ def test_loop_over_coordinate():
assert loop.step == 2
def test_sympy_assignment():
@pytest.mark.parametrize('default_assignment_simplifications', [False, True])
def test_sympy_assignment(default_assignment_simplifications):
import logging
logging.basicConfig(level=logging.DEBUG)
assignment = SympyAssignment(dst[0, 0](0), sp.log(x + 3) / sp.log(2) + sp.log(x ** 2 + 1))
ast = ps.create_kernel([assignment])
config = ps.CreateKernelConfig(default_assignment_simplifications=default_assignment_simplifications)
ast = ps.create_kernel([assignment], config=config)
code = ps.get_code_str(ast)
assert 'log1p' in code
# constant term is directly evaluated
assert 'log2' not in code
if default_assignment_simplifications:
assert 'log1p' in code
# constant term is directly evaluated
assert 'log2' not in code
else:
# no optimisations will be applied so the optimised version of log will not be in the code
assert 'log1p' not in code
assert 'log2' not in code
assignment.replace(assignment.lhs, dst[0, 0](1))
assignment.replace(assignment.rhs, sp.log(2))
Loading