Show how to create custom optimizations with `sympy.codegen.rewriting.optimize` in the documentation
In `sympy.codegen.rewriting` there is a function `optimize`:
```python
def optimize(expr, optimizations):
    """ Apply optimizations to an expression.

    Parameters
    ==========

    expr : expression
    optimizations : iterable of ``Optimization`` instances
        The optimizations will be sorted with respect to ``priority`` (highest first).

    Examples
    ========

    >>> from sympy import log, Symbol
    >>> from sympy.codegen.rewriting import optims_c99, optimize
    >>> x = Symbol('x')
    >>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99)
    log1p(x**2) + log2(x + 3)

    """
    for optim in sorted(optimizations, key=lambda opt: opt.priority, reverse=True):
        new_expr = optim(expr)
        if optim.cost_function is None:
            expr = new_expr
        else:
            before, after = map(lambda x: optim.cost_function(x), (expr, new_expr))
            if before > after:
                expr = new_expr
    return expr
```
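The loop above applies the optimizations in descending `priority` order and, whenever an optimization has a `cost_function`, keeps the rewritten expression only if its cost strictly decreases. The following sketch illustrates that gate with a toy rewrite `x**2 -> x*x` and two cost functions, `n_pows` and `n_muls`, which are hypothetical helpers written only for this illustration (they are not part of SymPy).

```python
from sympy import Mul, Symbol, UnevaluatedExpr
from sympy.codegen.rewriting import ReplaceOptim, optimize

x = Symbol('x')


def is_square(e):
    # Query: any power with exponent 2.
    return e.is_Pow and e.exp == 2


def to_product(p):
    # Replacement: the unevaluated product base*base.
    return UnevaluatedExpr(Mul(p.base, p.base, evaluate=False))


# Hypothetical cost functions: count Pow or Mul nodes in the expression tree.
def n_pows(e):
    return e.count(lambda s: s.is_Pow)


def n_muls(e):
    return e.count(lambda s: s.is_Mul)


# Same rewrite, two different cost models.
prefer_products = ReplaceOptim(is_square, to_product, cost_function=n_pows)
prefer_powers = ReplaceOptim(is_square, to_product, cost_function=n_muls)

expr = x**2 + 1
accepted = optimize(expr, [prefer_products])  # Pow count drops 1 -> 0: rewrite kept
rejected = optimize(expr, [prefer_powers])    # Mul count rises 0 -> 1: rewrite discarded
print(accepted, rejected)
```

Without a `cost_function`, the rewrite would be accepted unconditionally, which is how `create_expand_pow_optimization` behaves.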
We should use it in `create_kernel` to benefit from SymPy's optimizations (currently very few, and some of them duplicate optimizations that pystencils already performs) and to make it easy for users to incorporate their own optimizations into the expressions. `create_kernel` could accept an `Iterable` of `Optimization`s, with a default collection containing the optimizations that pystencils normally applies.
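A minimal sketch of what such an interface could look like, assuming only SymPy: `optimize_assignments` and `DEFAULT_OPTIMIZATIONS` are hypothetical names (not existing pystencils API), SymPy's `optims_c99` stands in for the default collection pystencils would choose, and `sympy.codegen.ast.Assignment` stands in for pystencils' own assignment type. In the proposed `create_kernel`, a step like this would run on the assignments before code generation.

```python
from typing import Iterable, List

from sympy import Symbol, log
from sympy.codegen.ast import Assignment
from sympy.codegen.rewriting import Optimization, optimize, optims_c99

# Hypothetical default; pystencils would supply its own collection here.
DEFAULT_OPTIMIZATIONS = tuple(optims_c99)


def optimize_assignments(
    assignments: Iterable[Assignment],
    optimizations: Iterable[Optimization] = DEFAULT_OPTIMIZATIONS,
) -> List[Assignment]:
    """Run SymPy's ``optimize`` on the right-hand side of every assignment."""
    # Materialize once so a generator passed by the caller is not
    # exhausted after the first assignment.
    optimizations = list(optimizations)
    return [Assignment(a.lhs, optimize(a.rhs, optimizations)) for a in assignments]


x, y = Symbol('x'), Symbol('y')
kernel_body = [Assignment(y, log(x + 3)/log(2) + log(x**2 + 1))]
print(optimize_assignments(kernel_body))
```

A user-defined optimization would then be passed as `optimize_assignments(body, [*DEFAULT_OPTIMIZATIONS, my_opt])`. Because `optimize` sorts by `priority` (and Python's sort is stable), the position in the list only matters among optimizations of equal priority.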
Implementing your own optimizations is easy, as SymPy's own `create_expand_pow_optimization` shows:
```python
from sympy import Mul, UnevaluatedExpr
from sympy.codegen.rewriting import ReplaceOptim


def create_expand_pow_optimization(limit):
    """ Creates an instance of :class:`ReplaceOptim` for expanding ``Pow``.

    The requirements for expansions are that the base needs to be a symbol
    and the exponent needs to be an Integer (and be less than or equal to
    ``limit``).

    Parameters
    ==========

    limit : int
         The highest power which is expanded into multiplication.

    Examples
    ========

    >>> from sympy import Symbol, sin
    >>> from sympy.codegen.rewriting import create_expand_pow_optimization
    >>> x = Symbol('x')
    >>> expand_opt = create_expand_pow_optimization(3)
    >>> expand_opt(x**5 + x**3)
    x**5 + x*x*x
    >>> expand_opt(x**5 + x**3 + sin(x)**3)
    x**5 + sin(x)**3 + x*x*x

    """
    return ReplaceOptim(
        # Query: integer powers of a plain symbol with |exponent| <= limit ...
        lambda e: e.is_Pow and e.base.is_symbol and e.exp.is_Integer and abs(e.exp) <= limit,
        # ... replaced by an unevaluated product (or its reciprocal).
        lambda p: (
            UnevaluatedExpr(Mul(*([p.base]*+p.exp), evaluate=False)) if p.exp > 0 else
            1/UnevaluatedExpr(Mul(*([p.base]*-p.exp), evaluate=False))
        ))
```
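Following the same pattern, a user-defined optimization only needs a query and a replacement. The sketch below assumes a target that provides a reciprocal square root; `rsqrt` is a hypothetical undefined SymPy function used purely for illustration (it is not part of `sympy.codegen`, and a code printer would have to know how to emit it). It is given `priority=2` only to show how ordering is controlled: `optimize` runs it before the expand-pow optimization, whose priority is the default of 1.

```python
from sympy import Function, Rational, Symbol, sqrt
from sympy.codegen.rewriting import (ReplaceOptim, create_expand_pow_optimization,
                                     optimize)

# Hypothetical target intrinsic, represented as an undefined SymPy function.
rsqrt = Function('rsqrt')

# Replace x**(-1/2) by rsqrt(x); priority=2 makes optimize() apply it first.
rsqrt_opt = ReplaceOptim(
    lambda e: e.is_Pow and e.exp == Rational(-1, 2),
    lambda p: rsqrt(p.base),
    priority=2,
)

x = Symbol('x')
# User optimizations mixed with a built-in one; optimize() sorts by priority.
my_optims = [create_expand_pow_optimization(3), rsqrt_opt]
result = optimize(1/sqrt(x) + x**3, my_optims)
print(result)
```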
Edited by Markus Holzer