Skip to content
Snippets Groups Projects
Commit 605bf8bd authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add sympy optimization: evaluate_constant_terms

parent e9b5aaa0
No related branches found
No related tags found
No related merge requests found
...@@ -6,10 +6,28 @@ See :func:`sympy.codegen.rewriting.optimize`. ...@@ -6,10 +6,28 @@ See :func:`sympy.codegen.rewriting.optimize`.
try: try:
from sympy.codegen.rewriting import optims_c99 from sympy.codegen.rewriting import optims_c99, optimize
from sympy.codegen.rewriting import ReplaceOptim
HAS_REWRITING = True
# Evaluates all constant terms
evaluate_constant_terms = ReplaceOptim(
lambda e: hasattr(e, 'is_constant') and e.is_constant,
lambda p: p.evalf()
)
optims_pystencils_cpu = [evaluate_constant_terms] + list(optims_c99)
optims_pystencils_gpu = [evaluate_constant_terms] + list(optims_c99)
except ImportError: except ImportError:
optims_c99 = [] optims_c99 = []
optims_pystencils_cpu = [] + list(optims_c99) # Evaluates all constant terms
optims_pystencils_gpu = [] + list(optims_c99) evaluate_constant_terms = ReplaceOptim(
lambda e: e.is_constant,
lambda p: p.evalf()
)
optims_pystencils_cpu = [evaluate_constant_terms] + list(optims_c99)
optims_pystencils_gpu = [evaluate_constant_terms] + list(optims_c99)
...@@ -15,3 +15,35 @@ def test_sympy_optimizations(): ...@@ -15,3 +15,35 @@ def test_sympy_optimizations():
ast = pystencils.create_kernel(assignments, target=target) ast = pystencils.create_kernel(assignments, target=target)
code = str(pystencils.show_code(ast)) code = str(pystencils.show_code(ast))
assert 'expm1(' in code assert 'expm1(' in code
def test_evaluate_constant_terms():
for target in ('cpu', 'gpu'):
x, y, z = pystencils.fields('x, y, z: float32[2d]')
# Triggers Sympy's expm1 optimization
assignments = pystencils.AssignmentCollection({
x[0, 0]: -sp.cos(1) + y[0, 0]
})
ast = pystencils.create_kernel(assignments, target=target)
code = str(pystencils.show_code(ast))
assert 'cos(' not in code
print(code)
def test_do_not_evaluate_constant_terms():
optimizations = pystencils.optimizations.optims_pystencils_cpu
optimizations.remove(pystencils.optimizations.evaluate_constant_terms)
for target in ('cpu', 'gpu'):
x, y, z = pystencils.fields('x, y, z: float32[2d]')
assignments = pystencils.AssignmentCollection({
x[0, 0]: -sp.cos(1) + y[0, 0]
})
ast = pystencils.create_kernel(assignments, target=target, optimizations=optimizations)
code = str(pystencils.show_code(ast))
assert 'cos(' in code
print(code)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment